├── examples ├── pytest │ ├── snapshots │ │ ├── __init__.py │ │ ├── snap_test_demo │ │ │ ├── test_file 1.txt │ │ │ ├── test_multiple_files 1.txt │ │ │ └── test_multiple_files 2.txt │ │ └── snap_test_demo.py │ └── test_demo.py ├── django_project │ ├── lists │ │ ├── __init__.py │ │ ├── migrations │ │ │ ├── __init__.py │ │ │ └── 0001_initial.py │ │ ├── snapshots │ │ │ ├── __init__.py │ │ │ └── snap_tests.py │ │ ├── apps.py │ │ ├── views.py │ │ ├── models.py │ │ ├── templates │ │ │ └── home.html │ │ └── tests.py │ ├── django_project │ │ ├── __init__.py │ │ ├── snapshots │ │ │ ├── __init__.py │ │ │ └── snap_tests.py │ │ ├── tests.py │ │ ├── wsgi.py │ │ ├── urls.py │ │ └── settings.py │ └── manage.py └── unittest │ ├── snapshots │ ├── __init__.py │ └── snap_test_demo.py │ └── test_demo.py ├── tests ├── conftest.py ├── test_module.py ├── test_sorted_dict.py ├── test_formatter.py ├── test_pytest.py └── test_snapshot_test.py ├── snapshottest ├── snapshot.py ├── __init__.py ├── error.py ├── generic_repr.py ├── sorted_dict.py ├── formatter.py ├── diff.py ├── nose.py ├── django.py ├── reporting.py ├── file.py ├── pytest.py ├── unittest.py ├── formatters.py └── module.py ├── setup.cfg ├── MANIFEST.in ├── .travis.yml ├── tox.ini ├── Makefile ├── .gitignore ├── LICENSE ├── setup.py ├── CHANGELOG.md └── README.md /examples/pytest/snapshots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/django_project/lists/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/unittest/snapshots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/django_project/django_project/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/django_project/lists/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/django_project/lists/snapshots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = "pytester" 2 | -------------------------------------------------------------------------------- /examples/django_project/django_project/snapshots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/pytest/snapshots/snap_test_demo/test_file 1.txt: -------------------------------------------------------------------------------- 1 | Hello, world! -------------------------------------------------------------------------------- /examples/pytest/snapshots/snap_test_demo/test_multiple_files 1.txt: -------------------------------------------------------------------------------- 1 | Hello, world 1! -------------------------------------------------------------------------------- /examples/pytest/snapshots/snap_test_demo/test_multiple_files 2.txt: -------------------------------------------------------------------------------- 1 | Hello, world 2! 
-------------------------------------------------------------------------------- /snapshottest/snapshot.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | class Snapshot(OrderedDict): 5 | pass 6 | -------------------------------------------------------------------------------- /examples/django_project/lists/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ListsConfig(AppConfig): 5 | name = "lists" 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | 3 | [flake8] 4 | exclude = snapshots .tox venv 5 | max-line-length = 88 6 | extend-ignore = E203 7 | 8 | [tool:pytest] 9 | addopts = --cov snapshottest 10 | testpaths = tests 11 | -------------------------------------------------------------------------------- /examples/django_project/lists/views.py: -------------------------------------------------------------------------------- 1 | from django.shortcuts import render 2 | from lists.models import List 3 | 4 | 5 | def home_page(request): 6 | lists = List.objects.all() 7 | return render(request, "home.html", {"lists": lists}) 8 | -------------------------------------------------------------------------------- /snapshottest/__init__.py: -------------------------------------------------------------------------------- 1 | from .snapshot import Snapshot 2 | from .generic_repr import GenericRepr 3 | from .module import assert_match_snapshot 4 | from .unittest import TestCase 5 | 6 | 7 | __all__ = ["Snapshot", "GenericRepr", "assert_match_snapshot", "TestCase"] 8 | -------------------------------------------------------------------------------- /examples/django_project/lists/models.py: 
-------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class List(models.Model): 5 | name = models.CharField(max_length=200) 6 | description = models.CharField(max_length=200) 7 | 8 | def __str__(self): 9 | return self.name 10 | -------------------------------------------------------------------------------- /examples/django_project/lists/templates/home.html: -------------------------------------------------------------------------------- 1 | 2 | Lists 3 | 4 | 5 | {% for list in lists %} 6 | 7 | {% endfor %} 8 |
{{forloop.counter}}: {{ list.name }}
: 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /examples/unittest/snapshots/snap_test_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestDemo::test_api_me 1'] = { 11 | 'url': '/me' 12 | } 13 | -------------------------------------------------------------------------------- /examples/django_project/django_project/snapshots/snap_tests.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestDemo::test_api_me 1'] = { 11 | 'url': '/me' 12 | } 13 | -------------------------------------------------------------------------------- /examples/django_project/lists/tests.py: -------------------------------------------------------------------------------- 1 | from lists.models import List 2 | from snapshottest.django import TestCase 3 | 4 | 5 | class ListTest(TestCase): 6 | def test_uses_home_template(self): 7 | List.objects.create(name="test") 8 | response = self.client.get("/") 9 | self.assertMatchSnapshot(response.content.decode()) 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include CHANGELOG.md 4 | 5 | recursive-exclude * __pycache__ 6 | recursive-exclude * *.py[co] 7 | recursive-exclude * .DS_Store 8 | 9 | include Makefile 10 | include tox.ini 11 | recursive-include examples *.html 12 | recursive-include examples *.py 13 | recursive-include 
examples *.txt 14 | recursive-include tests *.py 15 | -------------------------------------------------------------------------------- /snapshottest/error.py: -------------------------------------------------------------------------------- 1 | class SnapshotError(Exception): 2 | pass 3 | 4 | 5 | class SnapshotNotFound(SnapshotError): 6 | def __init__(self, module, test_name): 7 | super(SnapshotNotFound, self).__init__( 8 | "Snapshot '{snapshot_id!s}' not found in {snapshot_file!s}".format( 9 | snapshot_id=test_name, snapshot_file=module.filepath 10 | ) 11 | ) 12 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | python: 4 | - 3.5 5 | - 3.6 6 | - 3.7 7 | - 3.8 8 | cache: pip 9 | install: make install 10 | script: make test 11 | after_success: 12 | - coveralls 13 | matrix: 14 | fast_finish: true 15 | include: 16 | - name: lint 17 | python: '3.8' 18 | install: make install-tools 19 | script: make lint 20 | - name: format 21 | python: '3.8' 22 | install: make install-tools 23 | script: make format 24 | -------------------------------------------------------------------------------- /examples/django_project/django_project/tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from snapshottest.django import SimpleTestCase 4 | 5 | 6 | def api_client_get(url): 7 | return { 8 | "url": url, 9 | } 10 | 11 | 12 | class TestDemo(SimpleTestCase): 13 | def test_api_me(self): 14 | my_api_response = api_client_get("/me") 15 | self.assertMatchSnapshot(my_api_response) 16 | 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /examples/unittest/test_demo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 
import snapshottest 3 | 4 | 5 | def api_client_get(url): 6 | return { 7 | "url": url, 8 | } 9 | 10 | 11 | class TestDemo(snapshottest.TestCase): 12 | def setUp(self): 13 | pass 14 | 15 | def test_api_me(self): 16 | my_api_response = api_client_get("/me") 17 | self.assertMatchSnapshot(my_api_response) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /examples/django_project/django_project/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for django_project project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_project.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /examples/django_project/lists/snapshots/snap_tests.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['ListTest::test_uses_home_template 1'] = ''' 11 | Lists 12 | 13 | 14 | 15 | 16 | 17 |
1: test
: 18 | 19 | 20 | 21 | ''' 22 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | format 4 | lint 5 | py{35,36,37,38} 6 | 7 | [testenv] 8 | usedevelop = True 9 | extras = test 10 | commands = 11 | make test 12 | whitelist_externals = 13 | make 14 | bash 15 | passenv = 16 | CONTINUOUS_INTEGRATION 17 | 18 | [testenv:format] 19 | basepython = python3 20 | skip_install = True 21 | commands = 22 | make install-tools 23 | make format 24 | 25 | [testenv:lint] 26 | basepython = python3 27 | skip_install = True 28 | commands = 29 | make install-tools 30 | make lint 31 | -------------------------------------------------------------------------------- /snapshottest/generic_repr.py: -------------------------------------------------------------------------------- 1 | class GenericRepr(object): 2 | def __init__(self, representation): 3 | self.representation = representation 4 | 5 | def __repr__(self): 6 | return "GenericRepr({})".format(repr(self.representation)) 7 | 8 | def __eq__(self, other): 9 | return ( 10 | isinstance(other, GenericRepr) 11 | and self.representation == other.representation 12 | ) 13 | 14 | def __hash__(self): 15 | return hash(self.representation) 16 | 17 | @staticmethod 18 | def from_value(value): 19 | representation = repr(value) 20 | # Remove the hex id, if found. 21 | representation = representation.replace(hex(id(value)), "0x100000000") 22 | return GenericRepr(representation) 23 | -------------------------------------------------------------------------------- /examples/django_project/django_project/urls.py: -------------------------------------------------------------------------------- 1 | """django_project URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. 
For more information please see: 4 | https://docs.djangoproject.com/en/1.11/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.conf.urls import url, include 14 | 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) 15 | """ 16 | from django.conf.urls import url 17 | from lists import views 18 | 19 | 20 | urlpatterns = [ 21 | url(r"^$", views.home_page, name="home"), 22 | ] 23 | -------------------------------------------------------------------------------- /examples/django_project/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_project.settings") 7 | try: 8 | from django.core.management import execute_from_command_line 9 | except ImportError: 10 | # The above import may fail for some other reason. Ensure that the 11 | # issue is really that Django is missing to avoid masking other 12 | # exceptions on Python 2. 13 | try: 14 | import django # noqa: F401 15 | except ImportError: 16 | raise ImportError( 17 | "Couldn't import Django. Are you sure it's installed and " 18 | "available on your PYTHONPATH environment variable? Did you " 19 | "forget to activate a virtual environment?" 
20 | ) 21 | raise 22 | execute_from_command_line(sys.argv) 23 | -------------------------------------------------------------------------------- /examples/django_project/lists/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11.1 on 2017-05-20 04:08 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | initial = True 11 | 12 | dependencies = [] 13 | 14 | operations = [ 15 | migrations.CreateModel( 16 | name="List", 17 | fields=[ 18 | ( 19 | "id", 20 | models.AutoField( 21 | auto_created=True, 22 | primary_key=True, 23 | serialize=False, 24 | verbose_name="ID", 25 | ), 26 | ), 27 | ("name", models.CharField(max_length=200)), 28 | ("description", models.CharField(max_length=200)), 29 | ], 30 | ), 31 | ] 32 | -------------------------------------------------------------------------------- /snapshottest/sorted_dict.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | class SortedDict(OrderedDict): 5 | def __init__(self, values): 6 | super(SortedDict, self).__init__() 7 | 8 | try: 9 | sorted_items = sorted(values.items()) 10 | except TypeError: 11 | # Enums are not sortable 12 | sorted_items = values.items() 13 | for key, value in sorted_items: 14 | if isinstance(value, dict): 15 | self[key] = SortedDict(value) 16 | elif isinstance(value, list): 17 | self[key] = self._sort_list(value) 18 | else: 19 | self[key] = value 20 | 21 | def _sort_list(self, value): 22 | def sort(val): 23 | if isinstance(val, dict): 24 | return SortedDict(val) 25 | elif isinstance(val, list): 26 | return self._sort_list(val) 27 | else: 28 | return val 29 | 30 | return [sort(item) for item in value] 31 | -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | all: install test 2 | 3 | .PHONY: develop 4 | develop: install install-tools 5 | 6 | .PHONY: install 7 | install: 8 | pip install -e ".[test]" 9 | 10 | .PHONY: install-tools 11 | install-tools: 12 | pip install flake8 black==20.8b1 13 | 14 | .PHONY: test 15 | test: 16 | # Run Pytest tests (including examples) 17 | py.test --cov=snapshottest tests examples/pytest 18 | 19 | # Run Unittest Example 20 | python examples/unittest/test_demo.py 21 | 22 | # Run nose 23 | nosetests examples/unittest 24 | 25 | # Run Django Example 26 | cd examples/django_project && python manage.py test 27 | 28 | .PHONY: lint 29 | lint: 30 | flake8 31 | 32 | .PHONY: format 33 | format: 34 | black --check setup.py snapshottest tests examples --exclude 'snapshots\/snap_.*.py$$' 35 | 36 | .PHONY: format-fix 37 | format-fix: 38 | black setup.py snapshottest tests examples --exclude 'snapshots\/snap_.*.py$$' 39 | 40 | .PHONY: clean 41 | clean: 42 | rm -rf dist/ build/ 43 | 44 | .PHONY: publish 45 | publish: clean 46 | python3 setup.py sdist bdist_wheel 47 | twine upload dist/* 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | venv/ 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .pytest_cache/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | 58 | # Sphinx documentation 59 | docs/_build/ 60 | 61 | # PyBuilder 62 | target/ 63 | 64 | # IntelliJ 65 | .idea 66 | 67 | # OS X 68 | .DS_Store 69 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017-Present Syrus Akbary 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /snapshottest/formatter.py: -------------------------------------------------------------------------------- 1 | from .formatters import default_formatters 2 | 3 | 4 | class Formatter(object): 5 | formatters = default_formatters() 6 | 7 | def __init__(self, imports=None): 8 | self.htchar = " " * 4 9 | self.lfchar = "\n" 10 | self.indent = 0 11 | self.imports = imports 12 | 13 | def __call__(self, value, **args): 14 | return self.format(value, self.indent) 15 | 16 | def format(self, value, indent): 17 | formatter = self.get_formatter(value) 18 | for module, import_name in formatter.get_imports(): 19 | self.imports[module].add(import_name) 20 | return formatter.format(value, indent, self) 21 | 22 | def normalize(self, value): 23 | formatter = self.get_formatter(value) 24 | return formatter.normalize(value, self) 25 | 26 | @staticmethod 27 | def get_formatter(value): 28 | for formatter in Formatter.formatters: 29 | if formatter.can_format(value): 30 | return formatter 31 | 32 | # This should never happen as GenericFormatter is registered by default. 
33 | raise RuntimeError("No formatter found for value") 34 | 35 | @staticmethod 36 | def register_formatter(formatter): 37 | Formatter.formatters.insert(0, formatter) 38 | -------------------------------------------------------------------------------- /snapshottest/diff.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | from fastdiff import compare 3 | 4 | from .sorted_dict import SortedDict 5 | from .formatter import Formatter 6 | 7 | 8 | def format_line(line): 9 | line = line.rstrip("\n") 10 | if line.startswith("-"): 11 | return colored(line, "green", attrs=["bold"]) 12 | elif line.startswith("+"): 13 | return colored(line, "red", attrs=["bold"]) 14 | elif line.startswith("?"): 15 | return colored("") + colored(line, "yellow", attrs=["bold"]) 16 | 17 | return colored("") + colored(line, "white", attrs=["dark"]) 18 | 19 | 20 | class PrettyDiff(object): 21 | def __init__(self, obj, snapshottest): 22 | self.pretty = Formatter() 23 | self.snapshottest = snapshottest 24 | if isinstance(obj, dict): 25 | obj = SortedDict(obj) 26 | self.obj = self.pretty(obj) 27 | 28 | def __eq__(self, other): 29 | return isinstance(other, PrettyDiff) and self.obj == other.obj 30 | 31 | def __repr__(self): 32 | return repr(self.obj) 33 | 34 | def get_diff(self, other): 35 | text1 = "Received \n\n" + self.pretty(self.obj) 36 | text2 = "Snapshot \n\n" + self.pretty(other) 37 | 38 | lines = list(compare(text2, text1)) 39 | return [format_line(line) for line in lines] 40 | -------------------------------------------------------------------------------- /examples/pytest/snapshots/snap_test_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import GenericRepr, Snapshot 6 | from snapshottest.file import FileSnapshot 7 | 8 | 9 | snapshots = 
Snapshot() 10 | 11 | snapshots['test_me_endpoint 1'] = { 12 | 'url': '/me' 13 | } 14 | 15 | snapshots['test_unicode 1'] = 'pépère' 16 | 17 | snapshots['test_object 1'] = GenericRepr('SomeObject(3)') 18 | 19 | snapshots['test_file 1'] = FileSnapshot('snap_test_demo/test_file 1.txt') 20 | 21 | snapshots['test_multiple_files 1'] = FileSnapshot('snap_test_demo/test_multiple_files 1.txt') 22 | 23 | snapshots['test_multiple_files 2'] = FileSnapshot('snap_test_demo/test_multiple_files 2.txt') 24 | 25 | snapshots['test_nested_objects dict'] = { 26 | 'key': GenericRepr('#') 27 | } 28 | 29 | snapshots['test_nested_objects defaultdict'] = { 30 | 'key': [ 31 | GenericRepr('#') 32 | ] 33 | } 34 | 35 | snapshots['test_nested_objects list'] = [ 36 | GenericRepr('#') 37 | ] 38 | 39 | snapshots['test_nested_objects tuple'] = ( 40 | GenericRepr('#') 41 | ,) 42 | 43 | snapshots['test_nested_objects set'] = set([ 44 | GenericRepr('#') 45 | ]) 46 | 47 | snapshots['test_nested_objects frozenset'] = frozenset([ 48 | GenericRepr('#') 49 | ]) 50 | -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from snapshottest import Snapshot 4 | from snapshottest.module import SnapshotModule 5 | 6 | 7 | class TestSnapshotModuleLoading(object): 8 | def test_load_not_yet_saved(self, tmpdir): 9 | filepath = tmpdir.join("snap_new.py") 10 | assert not filepath.check() # file does not exist 11 | module = SnapshotModule("tests.snapshots.snap_new", str(filepath)) 12 | snapshots = module.load_snapshots() 13 | assert isinstance(snapshots, Snapshot) 14 | 15 | def test_load_missing_package(self, tmpdir): 16 | filepath = tmpdir.join("snap_import.py") 17 | filepath.write_text("import missing_package\n", "utf-8") 18 | module = SnapshotModule("tests.snapshots.snap_import", str(filepath)) 19 | with pytest.raises(ImportError): 20 | module.load_snapshots() 21 | 
22 | def test_load_corrupted_snapshot(self, tmpdir): 23 | filepath = tmpdir.join("snap_error.py") 24 | filepath.write_text("\n", "utf-8") 25 | module = SnapshotModule("tests.snapshots.snap_error", str(filepath)) 26 | with pytest.raises(SyntaxError): 27 | module.load_snapshots() 28 | 29 | def test_save_and_load_when_test_name_with_quotes(self, tmpdir): 30 | filepath = tmpdir.join("snap_error.py") 31 | module = SnapshotModule("tests.snapshots.snap_error", str(filepath)) 32 | module["quo'tes"] = "result" 33 | 34 | module.save() 35 | loaded = module.load_snapshots() 36 | 37 | assert loaded["quo'tes"] == "result" 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from setuptools import setup, find_packages 4 | 5 | with open("README.md") as f: 6 | readme = f.read() 7 | 8 | tests_require = ["pytest>=4.6", "pytest-cov", "nose", "django>=1.10.6"] 9 | 10 | setup( 11 | name="snapshottest", 12 | version="1.0.0a0", 13 | description="Snapshot testing for pytest, unittest, Django, and Nose", 14 | long_description=readme, 15 | long_description_content_type="text/markdown", 16 | author="Syrus Akbary", 17 | author_email="me@syrusakbary.com", 18 | url="https://github.com/syrusakbary/snapshottest", 19 | # custom PyPI classifier for pytest plugins 20 | entry_points={ 21 | "pytest11": [ 22 | "snapshottest = snapshottest.pytest", 23 | ], 24 | "nose.plugins.0.10": ["snapshottest = snapshottest.nose:SnapshotTestPlugin"], 25 | }, 26 | install_requires=["termcolor", "fastdiff>=0.1.4,<1"], 27 | tests_require=tests_require, 28 | extras_require={ 29 | "test": tests_require, 30 | "pytest": [ 31 | "pytest", 32 | ], 33 | "nose": [ 34 | "nose", 35 | ], 36 | }, 37 | requires_python=">=3.5", 38 | classifiers=[ 39 | "Development Status :: 5 - Production/Stable", 40 | "Framework :: Django", 41 | "Framework :: Pytest", 42 | "Intended Audience 
:: Developers", 43 | "Operating System :: OS Independent", 44 | "Topic :: Software Development :: Libraries", 45 | "Topic :: Software Development :: Testing", 46 | "Topic :: Software Development :: Testing :: Unit", 47 | ], 48 | license="MIT", 49 | packages=find_packages(exclude=("tests",)), 50 | ) 51 | -------------------------------------------------------------------------------- /snapshottest/nose.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from nose.plugins import Plugin 5 | 6 | from .module import SnapshotModule 7 | from .reporting import reporting_lines 8 | from .unittest import TestCase 9 | 10 | log = logging.getLogger("nose.plugins.snapshottest") 11 | 12 | 13 | class SnapshotTestPlugin(Plugin): 14 | name = "snapshottest" 15 | enabled = True 16 | 17 | separator1 = "=" * 70 18 | separator2 = "-" * 70 19 | 20 | def options(self, parser, env=os.environ): 21 | super(SnapshotTestPlugin, self).options(parser, env=env) 22 | parser.add_option( 23 | "--snapshot-update", 24 | action="store_true", 25 | default=False, 26 | dest="snapshot_update", 27 | help="Update the snapshots.", 28 | ) 29 | parser.add_option( 30 | "--snapshot-disable", 31 | action="store_true", 32 | dest="snapshot_disable", 33 | default=False, 34 | help="Disable special SnapshotTest", 35 | ) 36 | 37 | def configure(self, options, conf): 38 | super(SnapshotTestPlugin, self).configure(options, conf) 39 | self.snapshot_update = options.snapshot_update 40 | self.enabled = not options.snapshot_disable 41 | 42 | def wantClass(self, cls): 43 | if issubclass(cls, TestCase): 44 | cls.snapshot_should_update = self.snapshot_update 45 | 46 | def afterContext(self): 47 | if self.snapshot_update: 48 | for module in SnapshotModule.get_modules(): 49 | module.delete_unvisited() 50 | module.save() 51 | 52 | def report(self, stream): 53 | if not SnapshotModule.has_snapshots(): 54 | return 55 | 56 | stream.writeln(self.separator1) 57 | 
stream.writeln("SnapshotTest summary") 58 | stream.writeln(self.separator2) 59 | for line in reporting_lines("nosetests"): 60 | stream.writeln(line) 61 | stream.writeln(self.separator1) 62 | -------------------------------------------------------------------------------- /snapshottest/django.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase as dTestCase 2 | from django.test import SimpleTestCase as dSimpleTestCase 3 | from django.test.runner import DiscoverRunner 4 | 5 | from snapshottest.reporting import reporting_lines 6 | from .unittest import TestCase as uTestCase 7 | from .module import SnapshotModule 8 | 9 | 10 | class TestRunnerMixin(object): 11 | separator1 = "=" * 70 12 | separator2 = "-" * 70 13 | 14 | def __init__(self, snapshot_update=False, **kwargs): 15 | super(TestRunnerMixin, self).__init__(**kwargs) 16 | uTestCase.snapshot_should_update = snapshot_update 17 | 18 | @classmethod 19 | def add_arguments(cls, parser): 20 | super(TestRunnerMixin, cls).add_arguments(parser) 21 | parser.add_argument( 22 | "--snapshot-update", 23 | default=False, 24 | action="store_true", 25 | dest="snapshot_update", 26 | help="Update the snapshots automatically.", 27 | ) 28 | 29 | def run_tests(self, test_labels, extra_tests=None, **kwargs): 30 | result = super(TestRunnerMixin, self).run_tests( 31 | test_labels=test_labels, extra_tests=extra_tests, **kwargs 32 | ) 33 | self.print_report() 34 | if TestCase.snapshot_should_update: 35 | for module in SnapshotModule.get_modules(): 36 | module.delete_unvisited() 37 | module.save() 38 | 39 | return result 40 | 41 | def print_report(self): 42 | lines = list(reporting_lines("python manage.py test")) 43 | if lines: 44 | print("\n" + self.separator1) 45 | print("SnapshotTest summary") 46 | print(self.separator2) 47 | for line in lines: 48 | print(line) 49 | print(self.separator1) 50 | 51 | 52 | class TestRunner(TestRunnerMixin, DiscoverRunner): 53 | pass 54 | 55 
| 56 | class TestCase(uTestCase, dTestCase): 57 | pass 58 | 59 | 60 | class SimpleTestCase(uTestCase, dSimpleTestCase): 61 | pass 62 | -------------------------------------------------------------------------------- /tests/test_sorted_dict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import enum 3 | 4 | import pytest 5 | 6 | from snapshottest.sorted_dict import SortedDict 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "key, value", 11 | [ 12 | ("key1", "value"), 13 | ("key2", 42), 14 | ("key3", ["value"]), 15 | ("key4", [["value"]]), 16 | ("key5", {"key": "value"}), 17 | ("key6", [{"key": "value"}]), 18 | ("key7", {"key": ["value"]}), 19 | ("key8", [{"key": ["value"]}]), 20 | ], 21 | ) 22 | def test_sorted_dict(key, value): 23 | dic = dict([(key, value)]) 24 | assert SortedDict(dic)[key] == value 25 | 26 | 27 | def test_sorted_dict_string_key(): 28 | value = ("key", "value") 29 | dic = dict([value]) 30 | assert SortedDict(dic)[value[0]] == value[1] 31 | 32 | 33 | def test_sorted_dict_int_key(): 34 | value = (0, "value") 35 | dic = dict([value]) 36 | assert SortedDict(dic)[value[0]] == value[1] 37 | 38 | 39 | def test_sorted_dict_intenum(): 40 | class Fruit(enum.IntEnum): 41 | APPLE = 1 42 | ORANGE = 2 43 | 44 | dic = { 45 | Fruit.APPLE: 100, 46 | Fruit.ORANGE: 400, 47 | } 48 | assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] 49 | assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE] 50 | 51 | 52 | def test_sorted_dict_enum(): 53 | class Fruit(enum.Enum): 54 | APPLE = 1 55 | ORANGE = 2 56 | 57 | dic = { 58 | Fruit.APPLE: 100, 59 | Fruit.ORANGE: 400, 60 | } 61 | assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] 62 | assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE] 63 | 64 | 65 | def test_sorted_dict_enum_value(): 66 | class Fruit(enum.Enum): 67 | APPLE = 1 68 | ORANGE = 2 69 | 70 | value = ("fruit", Fruit) 71 | dic = dict([value]) 72 | assert SortedDict(dic)[value[0]] == 
value[1] 73 | 74 | 75 | def test_sorted_dict_enum_key(): 76 | class Fruit(enum.Enum): 77 | APPLE = 1 78 | ORANGE = 2 79 | 80 | value = (Fruit, "fruit") 81 | dic = dict([value]) 82 | assert SortedDict(dic)[value[0]] == value[1] 83 | -------------------------------------------------------------------------------- /snapshottest/reporting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from termcolor import colored 3 | 4 | from .module import SnapshotModule 5 | 6 | 7 | def reporting_lines(testing_cli): 8 | successful_snapshots = SnapshotModule.stats_successful_snapshots() 9 | bold = ["bold"] 10 | if successful_snapshots: 11 | yield (colored("{} snapshots passed", attrs=bold) + ".").format( 12 | successful_snapshots 13 | ) 14 | new_snapshots = SnapshotModule.stats_new_snapshots() 15 | if new_snapshots[0]: 16 | yield ( 17 | colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." 18 | ).format(*new_snapshots) 19 | inspect_str = colored( 20 | "Inspect your code or run with `{} --snapshot-update` to update them.".format( 21 | testing_cli 22 | ), 23 | attrs=["dark"], 24 | ) 25 | failed_snapshots = SnapshotModule.stats_failed_snapshots() 26 | if failed_snapshots[0]: 27 | yield ( 28 | colored("{} snapshots failed", "red", attrs=bold) 29 | + " in {} test suites. " 30 | + inspect_str 31 | ).format(*failed_snapshots) 32 | unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots() 33 | if unvisited_snapshots[0]: 34 | yield ( 35 | colored("{} snapshots deprecated", "yellow", attrs=bold) 36 | + " in {} test suites. 
" 37 | + inspect_str 38 | ).format(*unvisited_snapshots) 39 | 40 | 41 | def diff_report(left, right): 42 | return [ 43 | "stored snapshot should match the received value", 44 | "", 45 | colored("> ") 46 | + colored("Received value", "red", attrs=["bold"]) 47 | + colored(" does not match ", attrs=["bold"]) 48 | + colored( 49 | "stored snapshot `{}`".format( 50 | left.snapshottest.test_name, 51 | ), 52 | "green", 53 | attrs=["bold"], 54 | ) 55 | + colored(".", attrs=["bold"]), 56 | colored("") 57 | + "> " 58 | + os.path.relpath(left.snapshottest.module.filepath, os.getcwd()), 59 | "", 60 | ] + left.get_diff(right) 61 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 1.0.0a0 4 | 5 | ### BREAKING CHANGES 6 | 7 | - Require Python 3.5+ 8 | 9 | 10 | ## 0.6.0 11 | 12 | ### New features 13 | 14 | - Sort written snapshots. 15 | - Extract Django TestRunner to a mixin so it can be used with alternative base 16 | classes. 17 | 18 | ### Bug fixes 19 | 20 | - Handle SortedDict with keys other than strings. 21 | - Fix error when snapshotting mock Calls objects. 22 | - Fix formatting of `float("nan")` and `float("inf")`. 23 | - Adopt a valid PEP-508 version range for fastdiff. 24 | 25 | ### Other changes 26 | 27 | - Documentation improvements. 28 | - Add tests and changelog to sdist. 
29 | 30 | 31 | ## 0.5.1 32 | 33 | ### New features 34 | 35 | - Add named snapshots #18 36 | - Add support for file snapshots #54 37 | - Hide empty output in the Django test runner #60 38 | 39 | ### Bug fixes 40 | 41 | - Fix snapshot-update with nose #19 42 | - Fix comparisons against objects stored as GenericRepr #20 43 | - Fix setting snapshot_should_update on other TestCases #33 44 | - Fix using non-ASCII characters #31 45 | - Fix silent failure when snapshot files are invalid #45 46 | - Remove unused snapshots in the Django runner #43 47 | - Fix python3 multiline unicode snapshots #46 48 | - Fix checks against falsy snapshots #50 49 | - Various fixes in GenericFormatter and collection formatters #82 50 | - Fix pytest parametrize for multiline strings #87 51 | 52 | ### Other changes 53 | 54 | - Documentation improvements. 55 | - Add wheel distribution #11 56 | - Combine build scripts into a Makefile #83 57 | - Update fastdiff version 58 | 59 | 60 | ## 0.5.0 61 | 62 | * Add Django support. Closes #1 63 | - Add `snapshottest.django.TestRunner` 64 | - Add `snapshottest.django.TestCase` 65 | - Add `--snapshot-update` to the Django test command. You can use `python manage.py test --snapshot-update` 66 | * Fix #3: all dicts are sorted before saving and before comparing. 67 | 68 | ### Breaking changes 69 | 70 | * Drop support for `python 3.3`, since Django doesn't support that version of Python. 71 | * Since all dicts are sorted, this could be a breaking change for your tests.
72 | Use the `--snapshot-update` option to update your tests 73 | -------------------------------------------------------------------------------- /tests/test_formatter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | from math import isnan 4 | 5 | from snapshottest.formatter import Formatter 6 | 7 | import unittest.mock 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "text_value, expected", 12 | [ 13 | # basics 14 | ("abc", "'abc'"), 15 | ("", "''"), 16 | ("back\\slash", "'back\\\\slash'"), 17 | # various embedded quotes (single line) 18 | ("""it has "double quotes".""", """'it has "double quotes".'"""), 19 | ("""it's got single quotes""", '''"it's got single quotes"'''), 20 | ("""it's got "both quotes".""", """'it\\'s got "both quotes".'"""), 21 | # multiline gets formatted as triple-quoted 22 | ("one\ntwo\n", "'''one\ntwo\n'''"), 23 | ("three\n'''quotes", '"""three\n\'\'\'quotes"""'), 24 | ("so many\"\"\"\n'''quotes", "'''so many\"\"\"\n\\'\\'\\'quotes'''"), 25 | ], 26 | ) 27 | def test_text_formatting(text_value, expected): 28 | formatter = Formatter() 29 | formatted = formatter(text_value) 30 | assert formatted == expected 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "text_value, expected", 35 | [ 36 | ("encodage précis", "'encodage précis'"), 37 | ("精确的编码", "'精确的编码'"), 38 | # backslash [unicode repr can't just be `"u'{}'".format(value)`] 39 | ("omvänt\\snedstreck", "'omvänt\\\\snedstreck'"), 40 | # multiline 41 | ("ett\ntvå\n", "'''ett\ntvå\n'''"), 42 | ], 43 | ) 44 | def test_non_ascii_text_formatting(text_value, expected): 45 | formatter = Formatter() 46 | formatted = formatter(text_value) 47 | assert formatted == expected 48 | 49 | 50 | # https://github.com/syrusakbary/snapshottest/issues/115 51 | def test_can_normalize_unittest_mock_call_object(): 52 | formatter = Formatter() 53 | print(formatter.normalize(unittest.mock.call(1, 2, 3))) 54 | 55 | 56 | def 
test_can_normalize_iterator_objects(): 57 | formatter = Formatter() 58 | print(formatter.normalize(x for x in range(3))) 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "value", 63 | [ 64 | 0, 65 | 12.7, 66 | True, 67 | False, 68 | None, 69 | float("-inf"), 70 | float("inf"), 71 | ], 72 | ) 73 | def test_basic_formatting_parsing(value): 74 | formatter = Formatter() 75 | formatted = formatter(value) 76 | parsed = eval(formatted) 77 | assert parsed == value 78 | assert type(parsed) == type(value) 79 | 80 | 81 | def test_formatting_parsing_nan(): 82 | value = float("nan") 83 | 84 | formatter = Formatter() 85 | formatted = formatter(value) 86 | parsed = eval(formatted) 87 | assert isnan(parsed) 88 | -------------------------------------------------------------------------------- /examples/pytest/test_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import defaultdict 3 | 4 | from snapshottest.file import FileSnapshot 5 | 6 | 7 | def api_client_get(url): 8 | return { 9 | "url": url, 10 | } 11 | 12 | 13 | def test_me_endpoint(snapshot): 14 | """Testing the API for /me""" 15 | my_api_response = api_client_get("/me") 16 | snapshot.assert_match(my_api_response) 17 | 18 | 19 | def test_unicode(snapshot): 20 | """Simple test with unicode""" 21 | expect = u"pépère" 22 | snapshot.assert_match(expect) 23 | 24 | 25 | class SomeObject(object): 26 | def __init__(self, value): 27 | self.value = value 28 | 29 | def __repr__(self): 30 | return "SomeObject({})".format(repr(self.value)) 31 | 32 | 33 | def test_object(snapshot): 34 | """ 35 | Test a snapshot with a custom object. The object will be represented in the 36 | snapshot using `snapshottest.GenericRepr`. The snapshot will only match if the 37 | object's repr remains the same. 38 | """ 39 | test_value = SomeObject(3) 40 | snapshot.assert_match(test_value) 41 | 42 | 43 | def test_file(snapshot, tmpdir): 44 | """ 45 | Test a file snapshot. 
The file contents will be saved in a sub-folder of the 46 | snapshots folder. Useful for large files (e.g. media files) that aren't suitable 47 | for storage as text inside the snap_***.py file. 48 | """ 49 | temp_file = tmpdir.join("example.txt") 50 | temp_file.write("Hello, world!") 51 | snapshot.assert_match(FileSnapshot(str(temp_file))) 52 | 53 | 54 | def test_multiple_files(snapshot, tmpdir): 55 | """ 56 | Each file is stored separately with the snapshot's name inside the module's file 57 | snapshots folder. 58 | """ 59 | temp_file1 = tmpdir.join("example1.txt") 60 | temp_file1.write("Hello, world 1!") 61 | snapshot.assert_match(FileSnapshot(str(temp_file1))) 62 | 63 | temp_file1 = tmpdir.join("example2.txt") 64 | temp_file1.write("Hello, world 2!") 65 | snapshot.assert_match(FileSnapshot(str(temp_file1))) 66 | 67 | 68 | class ObjectWithBadRepr(object): 69 | def __repr__(self): 70 | return "#" 71 | 72 | 73 | def test_nested_objects(snapshot): 74 | obj = ObjectWithBadRepr() 75 | 76 | dict_ = {"key": obj} 77 | defaultdict_ = defaultdict(list, [("key", [obj])]) 78 | list_ = [obj] 79 | tuple_ = (obj,) 80 | set_ = set((obj,)) 81 | frozenset_ = frozenset((obj,)) 82 | 83 | snapshot.assert_match(dict_, "dict") 84 | snapshot.assert_match(defaultdict_, "defaultdict") 85 | snapshot.assert_match(list_, "list") 86 | snapshot.assert_match(tuple_, "tuple") 87 | snapshot.assert_match(set_, "set") 88 | snapshot.assert_match(frozenset_, "frozenset") 89 | -------------------------------------------------------------------------------- /snapshottest/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import filecmp 4 | 5 | from .formatter import Formatter 6 | from .formatters import BaseFormatter 7 | 8 | 9 | class FileSnapshot(object): 10 | def __init__(self, path): 11 | """ 12 | Create a file snapshot pointing to the specified `path`. 
In a snapshot, `path` 13 | is considered to be relative to the test module's "snapshots" folder. (This is 14 | done to prevent ugly path manipulations inside the snapshot file.) 15 | """ 16 | self.path = path 17 | 18 | def __repr__(self): 19 | return "FileSnapshot({})".format(repr(self.path)) 20 | 21 | def __eq__(self, other): 22 | return self.path == other.path 23 | 24 | 25 | class FileSnapshotFormatter(BaseFormatter): 26 | def can_format(self, value): 27 | return isinstance(value, FileSnapshot) 28 | 29 | def store(self, test, value): 30 | """ 31 | Copy the file from the test location to the snapshot location. 32 | 33 | If the original test file has an extension, the snapshot file will 34 | use the same extension. 35 | """ 36 | 37 | file_snapshot_dir = self.get_file_snapshot_dir(test) 38 | if not os.path.exists(file_snapshot_dir): 39 | os.makedirs(file_snapshot_dir, 0o0700) 40 | extension = os.path.splitext(value.path)[1] 41 | snapshot_file = os.path.join(file_snapshot_dir, test.test_name) + extension 42 | shutil.copy(value.path, snapshot_file) 43 | relative_snapshot_filename = os.path.relpath( 44 | snapshot_file, test.module.snapshot_dir 45 | ) 46 | return FileSnapshot(relative_snapshot_filename) 47 | 48 | def get_imports(self): 49 | return (("snapshottest.file", "FileSnapshot"),) 50 | 51 | def format(self, value, indent, formatter): 52 | return repr(value) 53 | 54 | def assert_value_matches_snapshot( 55 | self, test, test_value, snapshot_value, formatter 56 | ): 57 | snapshot_path = os.path.join(test.module.snapshot_dir, snapshot_value.path) 58 | files_identical = filecmp.cmp(test_value.path, snapshot_path, shallow=False) 59 | assert files_identical, "Stored file differs from test file" 60 | 61 | @staticmethod 62 | def get_file_snapshot_dir(test): 63 | """ 64 | Get the directory for storing file snapshots for `test`. 
65 | Snapshot files are stored under: 66 | snapshots/snap_/ 67 | Right next to where the snapshot module is stored: 68 | snapshots/snap_.py 69 | """ 70 | return os.path.join(test.module.snapshot_dir, test.module.module) 71 | 72 | 73 | Formatter.register_formatter(FileSnapshotFormatter()) 74 | -------------------------------------------------------------------------------- /tests/test_pytest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from snapshottest.pytest import PyTestSnapshotTest 4 | 5 | 6 | @pytest.fixture 7 | def options(): 8 | return {} 9 | 10 | 11 | @pytest.fixture 12 | def _apply_options(request, monkeypatch, options): 13 | for k, v in options.items(): 14 | monkeypatch.setattr(request.config, k, v, raising=False) 15 | 16 | 17 | @pytest.fixture 18 | def pytest_snapshot_test(request, _apply_options): 19 | return PyTestSnapshotTest(request) 20 | 21 | 22 | class TestPyTestSnapShotTest: 23 | def test_property_test_name(self, pytest_snapshot_test): 24 | pytest_snapshot_test.assert_match("counter") 25 | assert ( 26 | pytest_snapshot_test.test_name 27 | == "TestPyTestSnapShotTest.test_property_test_name 1" 28 | ) 29 | 30 | pytest_snapshot_test.assert_match("named", "named_test") 31 | assert ( 32 | pytest_snapshot_test.test_name 33 | == "TestPyTestSnapShotTest.test_property_test_name named_test" 34 | ) 35 | 36 | pytest_snapshot_test.assert_match("counter") 37 | assert ( 38 | pytest_snapshot_test.test_name 39 | == "TestPyTestSnapShotTest.test_property_test_name 2" 40 | ) 41 | 42 | 43 | def test_pytest_snapshottest_property_test_name(pytest_snapshot_test): 44 | pytest_snapshot_test.assert_match("counter") 45 | assert ( 46 | pytest_snapshot_test.test_name 47 | == "test_pytest_snapshottest_property_test_name 1" 48 | ) 49 | 50 | pytest_snapshot_test.assert_match("named", "named_test") 51 | assert ( 52 | pytest_snapshot_test.test_name 53 | == "test_pytest_snapshottest_property_test_name named_test" 54 | ) 
55 | 56 | pytest_snapshot_test.assert_match("counter") 57 | assert ( 58 | pytest_snapshot_test.test_name 59 | == "test_pytest_snapshottest_property_test_name 2" 60 | ) 61 | 62 | 63 | @pytest.mark.parametrize("arg", ["single line string"]) 64 | def test_pytest_snapshottest_property_test_name_parametrize_singleline( 65 | pytest_snapshot_test, arg 66 | ): 67 | pytest_snapshot_test.assert_match("counter") 68 | assert ( 69 | pytest_snapshot_test.test_name 70 | == "test_pytest_snapshottest_property_test_name_parametrize_singleline" 71 | "[single line string] 1" 72 | ) 73 | 74 | 75 | @pytest.mark.parametrize( 76 | "arg", 77 | [ 78 | """ 79 | multi 80 | line 81 | string 82 | """ 83 | ], 84 | ) 85 | def test_pytest_snapshottest_property_test_name_parametrize_multiline( 86 | pytest_snapshot_test, arg 87 | ): 88 | pytest_snapshot_test.assert_match("counter") 89 | assert ( 90 | pytest_snapshot_test.test_name 91 | == "test_pytest_snapshottest_property_test_name_parametrize_multiline" 92 | "[ multi line string ] 1" 93 | ) 94 | -------------------------------------------------------------------------------- /snapshottest/pytest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import re 3 | 4 | from .module import SnapshotModule, SnapshotTest 5 | from .diff import PrettyDiff 6 | from .reporting import reporting_lines, diff_report 7 | 8 | 9 | def pytest_addoption(parser): 10 | group = parser.getgroup("snapshottest") 11 | group.addoption( 12 | "--snapshot-update", 13 | action="store_true", 14 | default=False, 15 | dest="snapshot_update", 16 | help="Update the snapshots.", 17 | ) 18 | group.addoption( 19 | "--snapshot-verbose", 20 | action="store_true", 21 | default=False, 22 | help="Dump diagnostic and progress information.", 23 | ) 24 | 25 | 26 | class PyTestSnapshotTest(SnapshotTest): 27 | def __init__(self, request=None): 28 | self.request = request 29 | super(PyTestSnapshotTest, self).__init__() 30 | 31 | @property 32 | 
def module(self): 33 | return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath) 34 | 35 | @property 36 | def update(self): 37 | return self.request.config.option.snapshot_update 38 | 39 | @property 40 | def test_name(self): 41 | cls_name = getattr(self.request.node.cls, "__name__", "") 42 | flattened_node_name = re.sub( 43 | r"\s+", " ", self.request.node.name.replace(r"\n", " ") 44 | ) 45 | return "{}{} {}".format( 46 | "{}.".format(cls_name) if cls_name else "", 47 | flattened_node_name, 48 | self.curr_snapshot, 49 | ) 50 | 51 | 52 | class SnapshotSession(object): 53 | def __init__(self, config): 54 | self.verbose = config.getoption("snapshot_verbose") 55 | self.config = config 56 | 57 | def display(self, tr): 58 | if not SnapshotModule.has_snapshots(): 59 | return 60 | 61 | tr.write_sep("=", "SnapshotTest summary") 62 | 63 | for line in reporting_lines("pytest"): 64 | tr.write_line(line) 65 | 66 | 67 | def pytest_assertrepr_compare(op, left, right): 68 | if isinstance(left, PrettyDiff) and op == "==": 69 | return diff_report(left, right) 70 | 71 | 72 | @pytest.fixture 73 | def snapshot(request): 74 | with PyTestSnapshotTest(request) as snapshot_test: 75 | yield snapshot_test 76 | 77 | 78 | def pytest_terminal_summary(terminalreporter): 79 | if terminalreporter.config.option.snapshot_update: 80 | for module in SnapshotModule.get_modules(): 81 | module.delete_unvisited() 82 | module.save() 83 | 84 | terminalreporter.config._snapshotsession.display(terminalreporter) 85 | 86 | 87 | # force the other plugins to initialise first 88 | # (fixes issue with capture not being properly initialised) 89 | @pytest.mark.trylast 90 | def pytest_configure(config): 91 | config._snapshotsession = SnapshotSession(config) 92 | # config.pluginmanager.register(bs, "snapshottest") 93 | -------------------------------------------------------------------------------- /examples/django_project/django_project/settings.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Django settings for django_project project. 3 | 4 | Generated by 'django-admin startproject' using Django 1.11.1. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.11/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/1.11/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = "$5@im(@s1+p9a&ob#1osrq%*-sue%90o6q*cf0)$h@urtql^4@" 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | "django.contrib.admin", 35 | "django.contrib.auth", 36 | "django.contrib.contenttypes", 37 | "django.contrib.sessions", 38 | "django.contrib.messages", 39 | "django.contrib.staticfiles", 40 | "lists", 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | "django.middleware.security.SecurityMiddleware", 45 | "django.contrib.sessions.middleware.SessionMiddleware", 46 | "django.middleware.common.CommonMiddleware", 47 | "django.middleware.csrf.CsrfViewMiddleware", 48 | "django.contrib.auth.middleware.AuthenticationMiddleware", 49 | "django.contrib.messages.middleware.MessageMiddleware", 50 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 51 | ] 52 | 53 | ROOT_URLCONF = "django_project.urls" 54 | 55 | TEMPLATES = [ 56 | { 57 | "BACKEND": "django.template.backends.django.DjangoTemplates", 58 | "DIRS": [], 59 | "APP_DIRS": True, 60 | "OPTIONS": { 61 | "context_processors": [ 62 | 
"django.template.context_processors.debug", 63 | "django.template.context_processors.request", 64 | "django.contrib.auth.context_processors.auth", 65 | "django.contrib.messages.context_processors.messages", 66 | ], 67 | }, 68 | }, 69 | ] 70 | 71 | WSGI_APPLICATION = "django_project.wsgi.application" 72 | 73 | TEST_RUNNER = "snapshottest.django.TestRunner" 74 | 75 | # Database 76 | # https://docs.djangoproject.com/en/1.11/ref/settings/#databases 77 | 78 | DATABASES = { 79 | "default": { 80 | "ENGINE": "django.db.backends.sqlite3", 81 | } 82 | } 83 | 84 | 85 | # Password validation 86 | # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators 87 | 88 | AUTH_PASSWORD_VALIDATORS = [ 89 | { 90 | "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", # noqa: E501 91 | }, 92 | { 93 | "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", 94 | }, 95 | { 96 | "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", 97 | }, 98 | { 99 | "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", 100 | }, 101 | ] 102 | 103 | 104 | # Internationalization 105 | # https://docs.djangoproject.com/en/1.11/topics/i18n/ 106 | 107 | LANGUAGE_CODE = "en-us" 108 | 109 | TIME_ZONE = "UTC" 110 | 111 | USE_I18N = True 112 | 113 | USE_L10N = True 114 | 115 | USE_TZ = True 116 | 117 | 118 | # Static files (CSS, JavaScript, Images) 119 | # https://docs.djangoproject.com/en/1.11/howto/static-files/ 120 | 121 | STATIC_URL = "/static/" 122 | -------------------------------------------------------------------------------- /snapshottest/unittest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import inspect 3 | 4 | from .module import SnapshotModule, SnapshotTest 5 | from .diff import PrettyDiff 6 | from .reporting import diff_report 7 | 8 | 9 | class UnitTestSnapshotTest(SnapshotTest): 10 | def __init__(self, test_class, test_id, 
test_filepath, should_update, assertEqual): 11 | self.test_class = test_class 12 | self.test_id = test_id 13 | self.test_filepath = test_filepath 14 | self.assertEqual = assertEqual 15 | self.should_update = should_update 16 | super(UnitTestSnapshotTest, self).__init__() 17 | 18 | @property 19 | def module(self): 20 | return SnapshotModule.get_module_for_testpath(self.test_filepath) 21 | 22 | @property 23 | def update(self): 24 | return self.should_update 25 | 26 | def assert_equals(self, value, snapshot): 27 | self.assertEqual(value, snapshot) 28 | 29 | @property 30 | def test_name(self): 31 | class_name = self.test_class.__name__ 32 | test_name = self.test_id.split(".")[-1] 33 | return "{}::{} {}".format(class_name, test_name, self.curr_snapshot) 34 | 35 | 36 | # Inspired by https://gist.github.com/twolfson/13f5f5784f67fd49b245 37 | class TestCase(unittest.TestCase): 38 | 39 | snapshot_should_update = False 40 | 41 | @classmethod 42 | def setUpClass(cls): 43 | """On inherited classes, run our `setUp` method""" 44 | cls._snapshot_tests = [] 45 | cls._snapshot_file = inspect.getfile(cls) 46 | 47 | if cls is not TestCase and cls.setUp is not TestCase.setUp: 48 | orig_setUp = cls.setUp 49 | orig_tearDown = cls.tearDown 50 | 51 | def setUpOverride(self, *args, **kwargs): 52 | TestCase.setUp(self) 53 | return orig_setUp(self, *args, **kwargs) 54 | 55 | def tearDownOverride(self, *args, **kwargs): 56 | TestCase.tearDown(self) 57 | return orig_tearDown(self, *args, **kwargs) 58 | 59 | cls.setUp = setUpOverride 60 | cls.tearDown = tearDownOverride 61 | 62 | super(TestCase, cls).setUpClass() 63 | 64 | def comparePrettyDifs(self, obj1, obj2, msg): 65 | # self 66 | # assert obj1 == obj2 67 | if not (obj1 == obj2): 68 | raise self.failureException("\n".join(diff_report(obj1, obj2))) 69 | # raise self.failureException("DIFF") 70 | 71 | @classmethod 72 | def tearDownClass(cls): 73 | if cls._snapshot_tests: 74 | module = SnapshotModule.get_module_for_testpath(cls._snapshot_file) 
75 | module.save() 76 | super(TestCase, cls).tearDownClass() 77 | 78 | def setUp(self): 79 | """Do some custom setup""" 80 | super().setUp() 81 | # print dir(self.__module__) 82 | self.addTypeEqualityFunc(PrettyDiff, self.comparePrettyDifs) 83 | self._snapshot = UnitTestSnapshotTest( 84 | test_class=self.__class__, 85 | test_id=self.id(), 86 | test_filepath=self._snapshot_file, 87 | should_update=self.snapshot_should_update, 88 | assertEqual=self.assertEqual, 89 | ) 90 | self._snapshot_tests.append(self._snapshot) 91 | SnapshotTest._current_tester = self._snapshot 92 | 93 | def tearDown(self): 94 | """Do some custom setup""" 95 | # print dir(self.__module__) 96 | SnapshotTest._current_tester = None 97 | self._snapshot = None 98 | 99 | def assert_match_snapshot(self, value, name=""): 100 | self._snapshot.assert_match(value, name=name) 101 | 102 | assertMatchSnapshot = assert_match_snapshot 103 | -------------------------------------------------------------------------------- /tests/test_snapshot_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from collections import OrderedDict 3 | 4 | from snapshottest.module import SnapshotModule, SnapshotTest 5 | 6 | 7 | class GenericSnapshotTest(SnapshotTest): 8 | """A concrete SnapshotTest implementation for no particular testing framework""" 9 | 10 | def __init__(self, snapshot_module, update=False, current_test_id=None): 11 | self._generic_options = { 12 | "snapshot_module": snapshot_module, 13 | "update": update, 14 | "current_test_id": current_test_id or "test_mocked", 15 | } 16 | super(GenericSnapshotTest, self).__init__() 17 | 18 | @property 19 | def module(self): 20 | return self._generic_options["snapshot_module"] 21 | 22 | @property 23 | def update(self): 24 | return self._generic_options["update"] 25 | 26 | @property 27 | def test_name(self): 28 | return "{} {}".format( 29 | self._generic_options["current_test_id"], self.curr_snapshot 30 | ) 31 | 32 | def 
reinitialize(self): 33 | """Reset internal state, as though starting a new test run""" 34 | super(GenericSnapshotTest, self).__init__() 35 | 36 | 37 | def assert_snapshot_test_ran(snapshot_test, test_name=None): 38 | test_name = test_name or snapshot_test.test_name 39 | assert test_name in snapshot_test.module.visited_snapshots 40 | 41 | 42 | def assert_snapshot_test_succeeded(snapshot_test, test_name=None): 43 | test_name = test_name or snapshot_test.test_name 44 | assert_snapshot_test_ran(snapshot_test, test_name) 45 | assert test_name not in snapshot_test.module.failed_snapshots 46 | 47 | 48 | def assert_snapshot_test_failed(snapshot_test, test_name=None): 49 | test_name = test_name or snapshot_test.test_name 50 | assert_snapshot_test_ran(snapshot_test, test_name) 51 | assert test_name in snapshot_test.module.failed_snapshots 52 | 53 | 54 | @pytest.fixture(name="snapshot_test") 55 | def fixture_snapshot_test(tmpdir): 56 | filepath = tmpdir.join("snap_mocked.py") 57 | module = SnapshotModule("tests.snapshots.snap_mocked", str(filepath)) 58 | return GenericSnapshotTest(module) 59 | 60 | 61 | SNAPSHOTABLE_VALUES = [ 62 | "abc", 63 | b"abc", 64 | 123, 65 | 123.456, 66 | {"a": 1, "b": 2, "c": 3}, # dict 67 | ["a", "b", "c"], # list 68 | {"a", "b", "c"}, # set 69 | ("a", "b", "c"), # tuple 70 | ("a",), # tuple only have one element 71 | # Falsy values: 72 | None, 73 | False, 74 | "", 75 | b"", 76 | dict(), 77 | list(), 78 | set(), 79 | tuple(), 80 | 0, 81 | 0.0, 82 | # dict subclasses: 83 | # (Make sure snapshots don't just coerce to dict for comparison.) 
84 | OrderedDict([("a", 1), ("b", 2), ("c", 3)]), # same items as earlier dict 85 | OrderedDict([("c", 3), ("b", 2), ("a", 1)]), # same items, different order 86 | ] 87 | 88 | 89 | @pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr) 90 | def test_snapshot_matches_itself(snapshot_test, value): 91 | # first run stores the value as the snapshot 92 | snapshot_test.assert_match(value) 93 | assert_snapshot_test_succeeded(snapshot_test) 94 | 95 | # second run should compare stored snapshot and also succeed 96 | snapshot_test.reinitialize() 97 | snapshot_test.assert_match(value) 98 | assert_snapshot_test_succeeded(snapshot_test) 99 | 100 | 101 | @pytest.mark.parametrize( 102 | "value, other_value", 103 | [ 104 | pytest.param( 105 | value, 106 | other_value, 107 | id="snapshot {!r} shouldn't match {!r}".format(value, other_value), 108 | ) 109 | for value in SNAPSHOTABLE_VALUES 110 | for other_value in SNAPSHOTABLE_VALUES 111 | if other_value != value 112 | ], 113 | ) 114 | def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value): 115 | # first run stores the value as the snapshot 116 | snapshot_test.assert_match(value) 117 | assert_snapshot_test_succeeded(snapshot_test) 118 | 119 | # second run tries to match other_value, should fail 120 | snapshot_test.reinitialize() 121 | with pytest.raises(AssertionError): 122 | snapshot_test.assert_match(other_value) 123 | assert_snapshot_test_failed(snapshot_test) 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SnapshotTest [![travis][travis-image]][travis-url] [![pypi][pypi-image]][pypi-url] 2 | 3 | [travis-image]: https://img.shields.io/travis/syrusakbary/snapshottest.svg?style=flat 4 | [travis-url]: https://travis-ci.org/syrusakbary/snapshottest 5 | [pypi-image]: https://img.shields.io/pypi/v/snapshottest.svg?style=flat 6 | [pypi-url]: 
https://pypi.python.org/pypi/snapshottest 7 | 8 | 9 | Snapshot testing is a way to test your APIs without writing actual test cases. 10 | 11 | 1. A snapshot is a single state of your API, saved in a file. 12 | 2. You have a set of snapshots for your API endpoints. 13 | 3. Once you add a new feature, you can *automatically* generate new snapshots for the updated API. 14 | 15 | ## Installation 16 | 17 | $ pip install snapshottest 18 | 19 | 20 | ## Usage with unittest/nose 21 | 22 | ```python 23 | from snapshottest import TestCase 24 | 25 | class APITestCase(TestCase): 26 | def test_api_me(self): 27 | """Testing the API for /me""" 28 | my_api_response = api.client.get('/me') 29 | self.assertMatchSnapshot(my_api_response) 30 | 31 | # Set custom snapshot name: `gpg_response` 32 | my_gpg_response = api.client.get('/me?gpg_key') 33 | self.assertMatchSnapshot(my_gpg_response, 'gpg_response') 34 | ``` 35 | 36 | To update the snapshots automatically, run `nosetests --snapshot-update`. 37 | 38 | Check the [Unittest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/unittest). 39 | 40 | ## Usage with pytest 41 | 42 | ```python 43 | def test_mything(snapshot): 44 | """Testing the API for /me""" 45 | my_api_response = api.client.get('/me') 46 | snapshot.assert_match(my_api_response) 47 | 48 | # Set custom snapshot name: `gpg_response` 49 | my_gpg_response = api.client.get('/me?gpg_key') 50 | snapshot.assert_match(my_gpg_response, 'gpg_response') 51 | ``` 52 | 53 | To update the snapshots automatically, pass the `--snapshot-update` option to pytest. 54 | 55 | Check the [Pytest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/pytest).
56 | 57 | ## Usage with django 58 | Add to your settings: 59 | ```python 60 | TEST_RUNNER = 'snapshottest.django.TestRunner' 61 | ``` 62 | To create your snapshot tests: 63 | ```python 64 | from snapshottest.django import TestCase 65 | 66 | class APITestCase(TestCase): 67 | def test_api_me(self): 68 | """Testing the API for /me""" 69 | my_api_response = api.client.get('/me') 70 | self.assertMatchSnapshot(my_api_response) 71 | ``` 72 | If you want to update the snapshots automatically, run `python manage.py test --snapshot-update`. 73 | Check the [Django example](https://github.com/syrusakbary/snapshottest/tree/master/examples/django_project). 74 | 75 | ## Disabling terminal colors 76 | 77 | Set the environment variable `ANSI_COLORS_DISABLED` (to any value), e.g. 78 | 79 | ANSI_COLORS_DISABLED=1 pytest 80 | 81 | 82 | # Contributing 83 | 84 | After cloning this repo and configuring a virtualenv for snapshottest (optional, but highly recommended), ensure dependencies are installed by running: 85 | 86 | ```sh 87 | make develop 88 | ``` 89 | 90 | After developing, ensure your code is formatted properly by running: 91 | 92 | ```sh 93 | make format-fix 94 | ``` 95 | 96 | and then run the full test suite with: 97 | 98 | ```sh 99 | make lint 100 | # and 101 | make test 102 | ``` 103 | 104 | To test locally on all supported Python versions, you can use 105 | [tox](https://tox.readthedocs.io/): 106 | 107 | ```sh 108 | pip install tox # (if you haven't before) 109 | tox 110 | ``` 111 | 112 | # Notes 113 | 114 | This package is heavily inspired by [jest snapshot testing](https://facebook.github.io/jest/docs/snapshot-testing.html). 115 | 116 | # Reasons to use this package 117 | 118 | > Most of this content is taken from the [Jest snapshot blog post](https://facebook.github.io/jest/blog/2016/07/27/jest-14.html). 119 | 120 | We want to make it as frictionless as possible to write good tests that are useful.
121 | We observed that when engineers are provided with ready-to-use tools, they end up writing more tests, which in turn results in stable and healthy code bases. 122 | 123 | However, engineers frequently spend more time writing a test than writing the component itself. As a result, many people stopped writing tests altogether, which eventually led to instabilities. 124 | 125 | A typical snapshot test case for a mobile app renders a UI component, takes a screenshot, then compares it to a reference image stored alongside the test. The test will fail if the two images do not match: either the change is unexpected, or the screenshot needs to be updated to the new version of the UI component. 126 | 127 | 128 | ## Snapshot Testing with SnapshotTest 129 | 130 | A similar approach can be taken when it comes to testing your APIs. 131 | Instead of rendering the graphical UI, which would require building the entire app, you can quickly generate a serializable value for your API response (for example, the decoded body returned by a test client) and compare it against the stored snapshot.
132 | 133 | 134 | ## License 135 | 136 | [MIT License](https://github.com/syrusakbary/snapshottest/blob/master/LICENSE) 137 | 138 | [![coveralls][coveralls-image]][coveralls-url] 139 | 140 | [coveralls-image]: https://coveralls.io/repos/syrusakbary/snapshottest/badge.svg?branch=master&service=github 141 | [coveralls-url]: https://coveralls.io/github/syrusakbary/snapshottest?branch=master 142 | -------------------------------------------------------------------------------- /snapshottest/formatters.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | 4 | from .sorted_dict import SortedDict 5 | from .generic_repr import GenericRepr 6 | 7 | 8 | class BaseFormatter(object): 9 | def can_format(self, value): 10 | raise NotImplementedError() 11 | 12 | def format(self, value, indent, formatter): 13 | raise NotImplementedError() 14 | 15 | def get_imports(self): 16 | return () 17 | 18 | def assert_value_matches_snapshot( 19 | self, test, test_value, snapshot_value, formatter 20 | ): 21 | test.assert_equals(formatter.normalize(test_value), snapshot_value) 22 | 23 | def store(self, test, value): 24 | return value 25 | 26 | def normalize(self, value, formatter): 27 | return value 28 | 29 | 30 | class TypeFormatter(BaseFormatter): 31 | def __init__(self, types, format_func): 32 | self.types = types 33 | self.format_func = format_func 34 | 35 | def can_format(self, value): 36 | return isinstance(value, self.types) 37 | 38 | def format(self, value, indent, formatter): 39 | return self.format_func(value, indent, formatter) 40 | 41 | 42 | class CollectionFormatter(TypeFormatter): 43 | def normalize(self, value, formatter): 44 | iterator = iter(value.items()) if isinstance(value, dict) else iter(value) 45 | # https://github.com/syrusakbary/snapshottest/issues/115 46 | # Normally we shouldn't need to turn this into a list, but some iterable 47 | # constructors need a list not an iterator (e.g. 
unittest.mock.call). 48 | return value.__class__([formatter.normalize(item) for item in iterator]) 49 | 50 | 51 | class DefaultDictFormatter(TypeFormatter): 52 | def normalize(self, value, formatter): 53 | return defaultdict( 54 | value.default_factory, (formatter.normalize(item) for item in value.items()) 55 | ) 56 | 57 | 58 | def trepr(s): 59 | text = "\n".join([repr(line).lstrip("u")[1:-1] for line in s.split("\n")]) 60 | quotes, dquotes = "'''", '"""' 61 | if quotes in text: 62 | if dquotes in text: 63 | text = text.replace(quotes, "\\'\\'\\'") 64 | else: 65 | quotes = dquotes 66 | return "%s%s%s" % (quotes, text, quotes) 67 | 68 | 69 | def format_none(value, indent, formatter): 70 | return "None" 71 | 72 | 73 | def format_str(value, indent, formatter): 74 | if "\n" in value: 75 | # Is a multiline string, so we use '''{}''' for the repr 76 | return trepr(value) 77 | 78 | # Snapshots are saved with `from __future__ import unicode_literals`, 79 | # so the `u'...'` repr is unnecessary, even on Python 2 80 | return repr(value).lstrip("u") 81 | 82 | 83 | def format_float(value, indent, formatter): 84 | if math.isinf(value) or math.isnan(value): 85 | return 'float("%s")' % repr(value) 86 | return repr(value) 87 | 88 | 89 | def format_std_type(value, indent, formatter): 90 | return repr(value) 91 | 92 | 93 | def format_dict(value, indent, formatter): 94 | value = SortedDict(value) 95 | items = [ 96 | formatter.lfchar 97 | + formatter.htchar * (indent + 1) 98 | + formatter.format(key, indent) 99 | + ": " 100 | + formatter.format(value[key], indent + 1) 101 | for key in value 102 | ] 103 | return "{%s}" % (",".join(items) + formatter.lfchar + formatter.htchar * indent) 104 | 105 | 106 | def format_list(value, indent, formatter): 107 | return "[%s]" % format_sequence(value, indent, formatter) 108 | 109 | 110 | def format_sequence(value, indent, formatter): 111 | items = [ 112 | formatter.lfchar 113 | + formatter.htchar * (indent + 1) 114 | + formatter.format(item, indent 
+ 1) 115 | for item in value 116 | ] 117 | return ",".join(items) + formatter.lfchar + formatter.htchar * indent 118 | 119 | 120 | def format_tuple(value, indent, formatter): 121 | return "(%s%s" % ( 122 | format_sequence(value, indent, formatter), 123 | ",)" if len(value) == 1 else ")", 124 | ) 125 | 126 | 127 | def format_set(value, indent, formatter): 128 | return "set([%s])" % format_sequence(value, indent, formatter) 129 | 130 | 131 | def format_frozenset(value, indent, formatter): 132 | return "frozenset([%s])" % format_sequence(value, indent, formatter) 133 | 134 | 135 | class GenericFormatter(BaseFormatter): 136 | def can_format(self, value): 137 | return True 138 | 139 | def store(self, test, value): 140 | return GenericRepr.from_value(value) 141 | 142 | def normalize(self, value, formatter): 143 | return GenericRepr.from_value(value) 144 | 145 | def format(self, value, indent, formatter): 146 | if not isinstance(value, GenericRepr): 147 | value = GenericRepr.from_value(value) 148 | return repr(value) 149 | 150 | def get_imports(self): 151 | return [("snapshottest", "GenericRepr")] 152 | 153 | def assert_value_matches_snapshot( 154 | self, test, test_value, snapshot_value, formatter 155 | ): 156 | test_value = GenericRepr.from_value(test_value) 157 | # Assert equality between the representations to provide a nice textual diff. 
158 | test.assert_equals(test_value.representation, snapshot_value.representation) 159 | 160 | 161 | def default_formatters(): 162 | return [ 163 | TypeFormatter(type(None), format_none), 164 | DefaultDictFormatter(defaultdict, format_dict), 165 | CollectionFormatter(dict, format_dict), 166 | CollectionFormatter(tuple, format_tuple), 167 | CollectionFormatter(list, format_list), 168 | CollectionFormatter(set, format_set), 169 | CollectionFormatter(frozenset, format_frozenset), 170 | TypeFormatter((str,), format_str), 171 | TypeFormatter((float,), format_float), 172 | TypeFormatter((int, complex, bool, bytes), format_std_type), 173 | GenericFormatter(), 174 | ] 175 | -------------------------------------------------------------------------------- /snapshottest/module.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import errno 3 | import os 4 | import imp 5 | from collections import defaultdict 6 | import logging 7 | 8 | from .snapshot import Snapshot 9 | from .formatter import Formatter 10 | from .error import SnapshotNotFound 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def _escape_quotes(text): 17 | return text.replace("'", "\\'") 18 | 19 | 20 | class SnapshotModule(object): 21 | _snapshot_modules = {} 22 | 23 | def __init__(self, module, filepath): 24 | self._original_snapshot = None 25 | self._snapshots = None 26 | self.module = module 27 | self.filepath = filepath 28 | self.imports = defaultdict(set) 29 | self.visited_snapshots = set() 30 | self.new_snapshots = set() 31 | self.failed_snapshots = set() 32 | self.imports["snapshottest"].add("Snapshot") 33 | 34 | def load_snapshots(self): 35 | try: 36 | source = imp.load_source(self.module, self.filepath) 37 | # except FileNotFoundError: # Python 3 38 | except (IOError, OSError) as err: 39 | if err.errno == errno.ENOENT: 40 | return Snapshot() 41 | else: 42 | raise 43 | else: 44 | assert isinstance(source.snapshots, Snapshot) 45 | 
return source.snapshots 46 | 47 | def visit(self, snapshot_name): 48 | self.visited_snapshots.add(snapshot_name) 49 | 50 | def delete_unvisited(self): 51 | for unvisited in self.unvisited_snapshots: 52 | del self.snapshots[unvisited] 53 | 54 | @property 55 | def unvisited_snapshots(self): 56 | return set(self.snapshots.keys()) - self.visited_snapshots 57 | 58 | @classmethod 59 | def total_unvisited_snapshots(cls): 60 | unvisited_snapshots = 0 61 | unvisited_modules = 0 62 | for module in cls.get_modules(): 63 | unvisited_snapshot_len = len(module.unvisited_snapshots) 64 | unvisited_snapshots += unvisited_snapshot_len 65 | unvisited_modules += min(unvisited_snapshot_len, 1) 66 | 67 | return unvisited_snapshots, unvisited_modules 68 | 69 | @classmethod 70 | def get_modules(cls): 71 | return SnapshotModule._snapshot_modules.values() 72 | 73 | @classmethod 74 | def stats_for_module(cls, getter): 75 | count_snapshots = 0 76 | count_modules = 0 77 | for module in SnapshotModule._snapshot_modules.values(): 78 | length = getter(module) 79 | count_snapshots += length 80 | count_modules += min(length, 1) 81 | 82 | return count_snapshots, count_modules 83 | 84 | @classmethod 85 | def stats_unvisited_snapshots(cls): 86 | return cls.stats_for_module(lambda module: len(module.unvisited_snapshots)) 87 | 88 | @classmethod 89 | def stats_visited_snapshots(cls): 90 | return cls.stats_for_module(lambda module: len(module.visited_snapshots)) 91 | 92 | @classmethod 93 | def stats_new_snapshots(cls): 94 | return cls.stats_for_module(lambda module: len(module.new_snapshots)) 95 | 96 | @classmethod 97 | def stats_failed_snapshots(cls): 98 | return cls.stats_for_module(lambda module: len(module.failed_snapshots)) 99 | 100 | @classmethod 101 | def stats_successful_snapshots(cls): 102 | stats_visited = cls.stats_visited_snapshots() 103 | stats_failed = cls.stats_failed_snapshots() 104 | return stats_visited[0] - stats_failed[0] 105 | 106 | @classmethod 107 | def has_snapshots(cls): 108 | 
return cls.stats_visited_snapshots()[0] > 0 109 | 110 | @property 111 | def original_snapshot(self): 112 | if not self._original_snapshot: 113 | self._original_snapshot = self.load_snapshots() 114 | return self._original_snapshot 115 | 116 | @property 117 | def snapshots(self): 118 | if not self._snapshots: 119 | self._snapshots = Snapshot(self.original_snapshot) 120 | return self._snapshots 121 | 122 | def __getitem__(self, test_name): 123 | try: 124 | return self.snapshots[test_name] 125 | except KeyError: 126 | raise SnapshotNotFound(self, test_name) 127 | 128 | def __setitem__(self, key, value): 129 | if key not in self.snapshots: 130 | # It's a new test 131 | self.new_snapshots.add(key) 132 | self.snapshots[key] = value 133 | 134 | def mark_failed(self, key): 135 | return self.failed_snapshots.add(key) 136 | 137 | @property 138 | def snapshot_dir(self): 139 | return os.path.dirname(self.filepath) 140 | 141 | def save(self): 142 | if self.original_snapshot == self.snapshots: 143 | # If there are no changes, we do nothing 144 | return 145 | 146 | # Create the snapshot dir in case doesn't exist 147 | try: 148 | os.makedirs(self.snapshot_dir, 0o0700) 149 | except (IOError, OSError): 150 | pass 151 | 152 | # Create __init__.py in case doesn't exist 153 | open(os.path.join(self.snapshot_dir, "__init__.py"), "a").close() 154 | 155 | pretty = Formatter(self.imports) 156 | 157 | with codecs.open(self.filepath, "w", encoding="utf-8") as snapshot_file: 158 | snapshots_declarations = [ 159 | """snapshots['{}'] = {}""".format( 160 | _escape_quotes(key), pretty(self.snapshots[key]) 161 | ) 162 | for key in sorted(self.snapshots.keys()) 163 | ] 164 | 165 | imports = "\n".join( 166 | [ 167 | "from {} import {}".format( 168 | module, ", ".join(sorted(module_imports)) 169 | ) 170 | for module, module_imports in sorted(self.imports.items()) 171 | ] 172 | ) 173 | snapshot_file.write( 174 | """# -*- coding: utf-8 -*- 175 | # snapshottest: v1 - https://goo.gl/zC4yUc 176 | from 
__future__ import unicode_literals 177 | 178 | {} 179 | 180 | 181 | snapshots = Snapshot() 182 | 183 | {} 184 | """.format( 185 | imports, "\n\n".join(snapshots_declarations) 186 | ) 187 | ) 188 | 189 | @classmethod 190 | def get_module_for_testpath(cls, test_filepath): 191 | if test_filepath not in cls._snapshot_modules: 192 | dirname = os.path.dirname(test_filepath) 193 | snapshot_dir = os.path.join(dirname, "snapshots") 194 | 195 | snapshot_basename = "snap_{}.py".format( 196 | os.path.splitext(os.path.basename(test_filepath))[0] 197 | ) 198 | snapshot_filename = os.path.join(snapshot_dir, snapshot_basename) 199 | snapshot_module = "{}".format(os.path.splitext(snapshot_basename)[0]) 200 | 201 | cls._snapshot_modules[test_filepath] = SnapshotModule( 202 | snapshot_module, snapshot_filename 203 | ) 204 | 205 | return cls._snapshot_modules[test_filepath] 206 | 207 | 208 | class SnapshotTest(object): 209 | _current_tester = None 210 | 211 | def __init__(self): 212 | self.curr_snapshot = "" 213 | self.snapshot_counter = 1 214 | 215 | @property 216 | def module(self): 217 | raise NotImplementedError("module property needs to be implemented") 218 | 219 | @property 220 | def update(self): 221 | return False 222 | 223 | @property 224 | def test_name(self): 225 | raise NotImplementedError("test_name property needs to be implemented") 226 | 227 | def __enter__(self): 228 | SnapshotTest._current_tester = self 229 | return self 230 | 231 | def __exit__(self, type, value, tb): 232 | self.save_changes() 233 | SnapshotTest._current_tester = None 234 | 235 | def visit(self): 236 | self.module.visit(self.test_name) 237 | 238 | def fail(self): 239 | self.module.mark_failed(self.test_name) 240 | 241 | def store(self, data): 242 | formatter = Formatter.get_formatter(data) 243 | data = formatter.store(self, data) 244 | self.module[self.test_name] = data 245 | 246 | def assert_value_matches_snapshot(self, test_value, snapshot_value): 247 | formatter = 
Formatter.get_formatter(test_value) 248 | formatter.assert_value_matches_snapshot( 249 | self, test_value, snapshot_value, Formatter() 250 | ) 251 | 252 | def assert_equals(self, value, snapshot): 253 | assert value == snapshot 254 | 255 | def assert_match(self, value, name=""): 256 | self.curr_snapshot = name or self.snapshot_counter 257 | self.visit() 258 | if self.update: 259 | self.store(value) 260 | else: 261 | try: 262 | prev_snapshot = self.module[self.test_name] 263 | except SnapshotNotFound: 264 | self.store(value) # first time this test has been seen 265 | else: 266 | try: 267 | self.assert_value_matches_snapshot(value, prev_snapshot) 268 | except AssertionError: 269 | self.fail() 270 | raise 271 | 272 | if not name: 273 | self.snapshot_counter += 1 274 | 275 | def save_changes(self): 276 | self.module.save() 277 | 278 | 279 | def assert_match_snapshot(value, name=""): 280 | if not SnapshotTest._current_tester: 281 | raise Exception( 282 | "You need to use assert_match_snapshot in the SnapshotTest context." 283 | ) 284 | 285 | SnapshotTest._current_tester.assert_match(value, name) 286 | --------------------------------------------------------------------------------