├── examples
│ ├── pytest
│ │ ├── snapshots
│ │ │ ├── __init__.py
│ │ │ ├── snap_test_demo
│ │ │ │ ├── test_file 1.txt
│ │ │ │ ├── test_multiple_files 1.txt
│ │ │ │ └── test_multiple_files 2.txt
│ │ │ └── snap_test_demo.py
│ │ └── test_demo.py
│ ├── django_project
│ │ ├── lists
│ │ │ ├── __init__.py
│ │ │ ├── migrations
│ │ │ │ ├── __init__.py
│ │ │ │ └── 0001_initial.py
│ │ │ ├── snapshots
│ │ │ │ ├── __init__.py
│ │ │ │ └── snap_tests.py
│ │ │ ├── apps.py
│ │ │ ├── views.py
│ │ │ ├── models.py
│ │ │ ├── templates
│ │ │ │ └── home.html
│ │ │ └── tests.py
│ │ ├── django_project
│ │ │ ├── __init__.py
│ │ │ ├── snapshots
│ │ │ │ ├── __init__.py
│ │ │ │ └── snap_tests.py
│ │ │ ├── tests.py
│ │ │ ├── wsgi.py
│ │ │ ├── urls.py
│ │ │ └── settings.py
│ │ └── manage.py
│ └── unittest
│   ├── snapshots
│   │ ├── __init__.py
│   │ └── snap_test_demo.py
│   └── test_demo.py
├── tests
│ ├── conftest.py
│ ├── test_module.py
│ ├── test_sorted_dict.py
│ ├── test_formatter.py
│ ├── test_pytest.py
│ └── test_snapshot_test.py
├── snapshottest
│ ├── snapshot.py
│ ├── __init__.py
│ ├── error.py
│ ├── generic_repr.py
│ ├── sorted_dict.py
│ ├── formatter.py
│ ├── diff.py
│ ├── nose.py
│ ├── django.py
│ ├── reporting.py
│ ├── file.py
│ ├── pytest.py
│ ├── unittest.py
│ ├── formatters.py
│ └── module.py
├── setup.cfg
├── MANIFEST.in
├── .travis.yml
├── tox.ini
├── Makefile
├── .gitignore
├── LICENSE
├── setup.py
├── CHANGELOG.md
└── README.md
/examples/pytest/snapshots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/django_project/lists/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/unittest/snapshots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/django_project/lists/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/django_project/lists/snapshots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
# Load pytest's built-in ``pytester`` plugin, which provides fixtures such as
# ``testdir`` for running pytest in-process — presumably used by the plugin
# tests in this directory (TODO confirm).
pytest_plugins = "pytester"
2 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/snapshots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/pytest/snapshots/snap_test_demo/test_file 1.txt:
--------------------------------------------------------------------------------
1 | Hello, world!
--------------------------------------------------------------------------------
/examples/pytest/snapshots/snap_test_demo/test_multiple_files 1.txt:
--------------------------------------------------------------------------------
1 | Hello, world 1!
--------------------------------------------------------------------------------
/examples/pytest/snapshots/snap_test_demo/test_multiple_files 2.txt:
--------------------------------------------------------------------------------
1 | Hello, world 2!
--------------------------------------------------------------------------------
/snapshottest/snapshot.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 |
class Snapshot(OrderedDict):
    """Insertion-ordered mapping from snapshot name to stored value.

    It adds nothing to ``OrderedDict``; the dedicated subclass lets snapshot
    containers be identified with ``isinstance`` checks.
    """
6 |
--------------------------------------------------------------------------------
/examples/django_project/lists/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class ListsConfig(AppConfig):
    """Django application configuration for the ``lists`` example app."""

    name = "lists"
6 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 |
3 | [flake8]
4 | exclude = snapshots .tox venv
5 | max-line-length = 88
6 | extend-ignore = E203
7 |
8 | [tool:pytest]
9 | addopts = --cov snapshottest
10 | testpaths = tests
11 |
--------------------------------------------------------------------------------
/examples/django_project/lists/views.py:
--------------------------------------------------------------------------------
1 | from django.shortcuts import render
2 | from lists.models import List
3 |
4 |
def home_page(request):
    """Render ``home.html`` listing every ``List`` stored in the database."""
    context = {"lists": List.objects.all()}
    return render(request, "home.html", context)
8 |
--------------------------------------------------------------------------------
/snapshottest/__init__.py:
--------------------------------------------------------------------------------
1 | from .snapshot import Snapshot
2 | from .generic_repr import GenericRepr
3 | from .module import assert_match_snapshot
4 | from .unittest import TestCase
5 |
6 |
7 | __all__ = ["Snapshot", "GenericRepr", "assert_match_snapshot", "TestCase"]
8 |
--------------------------------------------------------------------------------
/examples/django_project/lists/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class List(models.Model):
    """A named list with a short description.

    Field definitions must stay in sync with ``migrations/0001_initial.py``.
    """

    name = models.CharField(max_length=200)
    description = models.CharField(max_length=200)

    def __str__(self):
        # Display a list by its name (e.g. in the admin and shell).
        return self.name
10 |
--------------------------------------------------------------------------------
/examples/django_project/lists/templates/home.html:
--------------------------------------------------------------------------------
1 |
2 |
Lists
3 |
4 |
5 | {% for list in lists %}
6 | | {{forloop.counter}}: {{ list.name }} |
7 | {% endfor %}
8 |
:
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/examples/unittest/snapshots/snap_test_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # snapshottest: v1 - https://goo.gl/zC4yUc
3 | from __future__ import unicode_literals
4 |
5 | from snapshottest import Snapshot
6 |
7 |
8 | snapshots = Snapshot()
9 |
10 | snapshots['TestDemo::test_api_me 1'] = {
11 | 'url': '/me'
12 | }
13 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/snapshots/snap_tests.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # snapshottest: v1 - https://goo.gl/zC4yUc
3 | from __future__ import unicode_literals
4 |
5 | from snapshottest import Snapshot
6 |
7 |
8 | snapshots = Snapshot()
9 |
10 | snapshots['TestDemo::test_api_me 1'] = {
11 | 'url': '/me'
12 | }
13 |
--------------------------------------------------------------------------------
/examples/django_project/lists/tests.py:
--------------------------------------------------------------------------------
1 | from lists.models import List
2 | from snapshottest.django import TestCase
3 |
4 |
5 | class ListTest(TestCase):
6 | def test_uses_home_template(self):
7 | List.objects.create(name="test")
8 | response = self.client.get("/")
9 | self.assertMatchSnapshot(response.content.decode())
10 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include CHANGELOG.md
4 |
5 | recursive-exclude * __pycache__
6 | recursive-exclude * *.py[co]
7 | recursive-exclude * .DS_Store
8 |
9 | include Makefile
10 | include tox.ini
11 | recursive-include examples *.html
12 | recursive-include examples *.py
13 | recursive-include examples *.txt
14 | recursive-include tests *.py
15 |
--------------------------------------------------------------------------------
/snapshottest/error.py:
--------------------------------------------------------------------------------
class SnapshotError(Exception):
    """Base class for snapshottest errors."""


class SnapshotNotFound(SnapshotError):
    """Raised when a snapshot name has no stored value in a snapshot module.

    :param module: object exposing the snapshot file path as ``filepath``.
    :param test_name: the snapshot id that was looked up.
    """

    def __init__(self, module, test_name):
        template = "Snapshot '{snapshot_id!s}' not found in {snapshot_file!s}"
        message = template.format(
            snapshot_id=test_name, snapshot_file=module.filepath
        )
        super().__init__(message)
12 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | sudo: false
3 | python:
4 | - 3.5
5 | - 3.6
6 | - 3.7
7 | - 3.8
8 | cache: pip
9 | install: make install
10 | script: make test
11 | after_success:
12 | - coveralls
13 | matrix:
14 | fast_finish: true
15 | include:
16 | - name: lint
17 | python: '3.8'
18 | install: make install-tools
19 | script: make lint
20 | - name: format
21 | python: '3.8'
22 | install: make install-tools
23 | script: make format
24 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snapshottest.django import SimpleTestCase
4 |
5 |
def api_client_get(url):
    """Fake API client used by the example: echo ``url`` back in a dict."""
    response = dict(url=url)
    return response
10 |
11 |
class TestDemo(SimpleTestCase):
    """Snapshot a fake API response from a Django ``SimpleTestCase``."""

    def test_api_me(self):
        # Compare the fake response for /me against the stored snapshot.
        self.assertMatchSnapshot(api_client_get("/me"))


if __name__ == "__main__":
    unittest.main()
20 |
--------------------------------------------------------------------------------
/examples/unittest/test_demo.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import snapshottest
3 |
4 |
def api_client_get(url):
    """Fake API client used by the example: echo ``url`` back in a dict."""
    response = dict(url=url)
    return response
9 |
10 |
class TestDemo(snapshottest.TestCase):
    """Snapshot a fake API response from a plain ``unittest`` test case."""

    def setUp(self):
        # Intentionally a no-op override of setUp, kept as part of the example.
        pass

    def test_api_me(self):
        # Compare the fake response for /me against the stored snapshot.
        self.assertMatchSnapshot(api_client_get("/me"))


if __name__ == "__main__":
    unittest.main()
22 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for django_project project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.wsgi import get_wsgi_application
13 |
# Use this project's settings unless the environment already names a settings
# module. This is set before get_wsgi_application() runs, which presumably
# needs the settings module to be known.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_project.settings")

# Module-level WSGI callable named ``application`` (see the module docstring).
application = get_wsgi_application()
17 |
--------------------------------------------------------------------------------
/examples/django_project/lists/snapshots/snap_tests.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # snapshottest: v1 - https://goo.gl/zC4yUc
3 | from __future__ import unicode_literals
4 |
5 | from snapshottest import Snapshot
6 |
7 |
8 | snapshots = Snapshot()
9 |
10 | snapshots['ListTest::test_uses_home_template 1'] = '''
11 | Lists
12 |
13 |
14 |
15 | | 1: test |
16 |
17 |
:
18 |
19 |
20 |
21 | '''
22 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | format
4 | lint
5 | py{35,36,37,38}
6 |
7 | [testenv]
8 | usedevelop = True
9 | extras = test
10 | commands =
11 | make test
12 | whitelist_externals =
13 | make
14 | bash
15 | passenv =
16 | CONTINUOUS_INTEGRATION
17 |
18 | [testenv:format]
19 | basepython = python3
20 | skip_install = True
21 | commands =
22 | make install-tools
23 | make format
24 |
25 | [testenv:lint]
26 | basepython = python3
27 | skip_install = True
28 | commands =
29 | make install-tools
30 | make lint
31 |
--------------------------------------------------------------------------------
/snapshottest/generic_repr.py:
--------------------------------------------------------------------------------
class GenericRepr(object):
    """Stand-in for a value that is stored in a snapshot by its ``repr``.

    Two instances compare equal (and hash alike) when their representation
    strings are equal.
    """

    def __init__(self, representation):
        self.representation = representation

    def __repr__(self):
        return "GenericRepr({})".format(repr(self.representation))

    def __eq__(self, other):
        return (
            isinstance(other, GenericRepr)
            and self.representation == other.representation
        )

    def __hash__(self):
        return hash(self.representation)

    @staticmethod
    def from_value(value):
        """Build a ``GenericRepr`` from ``repr(value)`` with its address masked.

        The object's id would differ between runs, so every occurrence of it
        in the repr is replaced by the fixed placeholder ``0x100000000``.
        CPython formats default-repr addresses with ``%p``, which on some
        platforms (e.g. Windows) is zero-padded and upper-case, so the match
        ignores case and leading zeros instead of looking only for
        ``hex(id(value))``.

        :param value: any object.
        :returns: a ``GenericRepr`` wrapping the stabilized representation.
        """
        import re

        representation = repr(value)
        address = re.compile("0x0*{:x}".format(id(value)), flags=re.IGNORECASE)
        representation = address.sub("0x100000000", representation)
        return GenericRepr(representation)
23 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/urls.py:
--------------------------------------------------------------------------------
"""django_project URL Configuration

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/stable/topics/http/urls/
Examples:
Function views
    1. Add an import:  from my_app import views
    2. Add a URL to urlpatterns:  re_path(r'^$', views.home, name='home')
Class-based views
    1. Add an import:  from other_app.views import Home
    2. Add a URL to urlpatterns:  re_path(r'^$', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, re_path
    2. Add a URL to urlpatterns:  re_path(r'^blog/', include('blog.urls'))
"""
try:
    # Django >= 2.0 provides re_path; django.conf.urls.url() was removed in
    # Django 4.0, so it can no longer be imported unconditionally.
    from django.urls import re_path
except ImportError:
    # Older Django releases (the test extras allow django>=1.10.6) only
    # provide url(), which accepts the same regex-based arguments.
    from django.conf.urls import url as re_path

from lists import views


urlpatterns = [
    re_path(r"^$", views.home_page, name="home"),
]
23 |
--------------------------------------------------------------------------------
/examples/django_project/manage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 |
if __name__ == "__main__":
    # Use this project's settings unless the environment already names one.
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_project.settings")
    try:
        from django.core.management import execute_from_command_line
    except ImportError:
        # The import above can fail for reasons other than Django being
        # absent (e.g. an ImportError raised from inside Django). Probe for
        # the bare ``django`` package: only when that is missing is the error
        # replaced by an installation hint; otherwise the original error is
        # re-raised unchanged so it is not masked.
        try:
            import django  # noqa: F401
        except ImportError:
            raise ImportError(
                "Couldn't import Django. Are you sure it's installed and "
                "available on your PYTHONPATH environment variable? Did you "
                "forget to activate a virtual environment?"
            )
        raise
    # Dispatch to the management command named on the command line.
    execute_from_command_line(sys.argv)
23 |
--------------------------------------------------------------------------------
/examples/django_project/lists/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11.1 on 2017-05-20 04:08
3 | from __future__ import unicode_literals
4 |
5 | from django.db import migrations, models
6 |
7 |
class Migration(migrations.Migration):
    """Initial schema for the ``lists`` app: creates the ``List`` model."""

    # First migration of the app; it depends on no other migration.
    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name="List",
            fields=[
                (
                    "id",
                    # Automatically added primary key column.
                    models.AutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                # These must stay in sync with lists.models.List.
                ("name", models.CharField(max_length=200)),
                ("description", models.CharField(max_length=200)),
            ],
        ),
    ]
32 |
--------------------------------------------------------------------------------
/snapshottest/sorted_dict.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 |
class SortedDict(OrderedDict):
    """``OrderedDict`` whose keys are sorted, applied recursively.

    Nested dicts, including dicts inside (nested) lists, are converted to
    ``SortedDict`` too. Other containers such as tuples are left untouched.
    When the keys cannot be ordered (e.g. plain ``Enum`` members) the input's
    own iteration order is kept instead.
    """

    def __init__(self, values):
        super().__init__()

        try:
            ordered_items = sorted(values.items())
        except TypeError:
            # Unorderable keys: fall back to the mapping's own order.
            ordered_items = values.items()
        for key, item in ordered_items:
            self[key] = self._normalize(item)

    def _normalize(self, item):
        """Convert dicts to ``SortedDict`` and recurse into lists."""
        if isinstance(item, dict):
            return SortedDict(item)
        if isinstance(item, list):
            return self._sort_list(item)
        return item

    def _sort_list(self, value):
        """Return a new list with every element normalized recursively."""
        return [self._normalize(element) for element in value]
31 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: install test
2 |
3 | .PHONY: develop
4 | develop: install install-tools
5 |
6 | .PHONY: install
7 | install:
8 | pip install -e ".[test]"
9 |
10 | .PHONY: install-tools
11 | install-tools:
12 | pip install flake8 black==20.8b1
13 |
14 | .PHONY: test
15 | test:
16 | # Run Pytest tests (including examples)
17 | py.test --cov=snapshottest tests examples/pytest
18 |
19 | # Run Unittest Example
20 | python examples/unittest/test_demo.py
21 |
22 | # Run nose
23 | nosetests examples/unittest
24 |
25 | # Run Django Example
26 | cd examples/django_project && python manage.py test
27 |
28 | .PHONY: lint
29 | lint:
30 | flake8
31 |
32 | .PHONY: format
33 | format:
34 | black --check setup.py snapshottest tests examples --exclude 'snapshots\/snap_.*.py$$'
35 |
36 | .PHONY: format-fix
37 | format-fix:
38 | black setup.py snapshottest tests examples --exclude 'snapshots\/snap_.*.py$$'
39 |
40 | .PHONY: clean
41 | clean:
42 | rm -rf dist/ build/
43 |
44 | .PHONY: publish
45 | publish: clean
46 | python3 setup.py sdist bdist_wheel
47 | twine upload dist/*
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io
2 |
3 | ### Python ###
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | venv/
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .pytest_cache/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *,cover
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 |
58 | # Sphinx documentation
59 | docs/_build/
60 |
61 | # PyBuilder
62 | target/
63 |
64 | # IntelliJ
65 | .idea
66 |
67 | # OS X
68 | .DS_Store
69 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2017-Present Syrus Akbary
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/snapshottest/formatter.py:
--------------------------------------------------------------------------------
1 | from .formatters import default_formatters
2 |
3 |
class Formatter(object):
    """Render values as Python source text for snapshot files.

    Each value is handed to the first entry of ``formatters`` whose
    ``can_format`` accepts it; any imports that formatter declares are
    recorded in ``self.imports``.
    """

    # Class-level and shared by every instance; register_formatter() prepends
    # to it, so later registrations take priority over the defaults.
    formatters = default_formatters()

    def __init__(self, imports=None):
        """
        :param imports: mapping of module name to a set of imported names,
            filled in as values are formatted. It must create missing entries
            on access (e.g. ``defaultdict(set)``). Defaults to a new
            ``defaultdict(set)``.
        """
        from collections import defaultdict

        self.htchar = " " * 4
        self.lfchar = "\n"
        self.indent = 0
        # format() always records imports into this mapping, so it must never
        # be None: that raised TypeError for any value whose formatter
        # declares imports (e.g. when constructed as Formatter()).
        self.imports = defaultdict(set) if imports is None else imports

    def __call__(self, value, **args):
        """Format ``value`` at the base indent; extra kwargs are ignored."""
        return self.format(value, self.indent)

    def format(self, value, indent):
        """Format ``value`` at ``indent``, recording the imports it needs."""
        formatter = self.get_formatter(value)
        for module, import_name in formatter.get_imports():
            self.imports[module].add(import_name)
        return formatter.format(value, indent, self)

    def normalize(self, value):
        """Return ``value`` as normalized by its matching formatter."""
        formatter = self.get_formatter(value)
        return formatter.normalize(value, self)

    @staticmethod
    def get_formatter(value):
        """Return the first registered formatter that accepts ``value``.

        :raises RuntimeError: if no registered formatter can format it.
        """
        for formatter in Formatter.formatters:
            if formatter.can_format(value):
                return formatter

        # This should never happen as GenericFormatter is registered by default.
        raise RuntimeError("No formatter found for value")

    @staticmethod
    def register_formatter(formatter):
        """Register ``formatter`` ahead of all previously registered ones."""
        Formatter.formatters.insert(0, formatter)
38 |
--------------------------------------------------------------------------------
/snapshottest/diff.py:
--------------------------------------------------------------------------------
1 | from termcolor import colored
2 | from fastdiff import compare
3 |
4 | from .sorted_dict import SortedDict
5 | from .formatter import Formatter
6 |
7 |
def format_line(line):
    """Colorize one line of diff output for terminal display.

    Lines starting with ``-`` are green and ``+`` red (both bold), ``?`` hint
    lines are bold yellow, and anything else is dark white. The last two are
    prefixed with ``colored("")`` (presumably a termcolor reset sequence).
    """
    text = line.rstrip("\n")
    marker = text[:1]
    if marker == "-":
        return colored(text, "green", attrs=["bold"])
    if marker == "+":
        return colored(text, "red", attrs=["bold"])
    prefix = colored("")
    if marker == "?":
        return prefix + colored(text, "yellow", attrs=["bold"])
    return prefix + colored(text, "white", attrs=["dark"])
18 |
19 |
class PrettyDiff(object):
    """Wrap a value so a snapshot mismatch can be shown as a colored diff.

    The value is pretty-printed once at construction time; equality and
    ``repr`` operate on that formatted text.
    """

    def __init__(self, obj, snapshottest):
        self.pretty = Formatter()
        self.snapshottest = snapshottest
        # Dicts are key-sorted first so insertion order does not matter.
        normalized = SortedDict(obj) if isinstance(obj, dict) else obj
        self.obj = self.pretty(normalized)

    def __eq__(self, other):
        if not isinstance(other, PrettyDiff):
            return False
        return self.obj == other.obj

    def __repr__(self):
        return repr(self.obj)

    def get_diff(self, other):
        """Return colorized diff lines going from ``other`` to this value."""
        received = "Received \n\n" + self.pretty(self.obj)
        stored = "Snapshot \n\n" + self.pretty(other)
        return [format_line(entry) for entry in compare(stored, received)]
40 |
--------------------------------------------------------------------------------
/examples/pytest/snapshots/snap_test_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # snapshottest: v1 - https://goo.gl/zC4yUc
3 | from __future__ import unicode_literals
4 |
5 | from snapshottest import GenericRepr, Snapshot
6 | from snapshottest.file import FileSnapshot
7 |
8 |
9 | snapshots = Snapshot()
10 |
11 | snapshots['test_me_endpoint 1'] = {
12 | 'url': '/me'
13 | }
14 |
15 | snapshots['test_unicode 1'] = 'pépère'
16 |
17 | snapshots['test_object 1'] = GenericRepr('SomeObject(3)')
18 |
19 | snapshots['test_file 1'] = FileSnapshot('snap_test_demo/test_file 1.txt')
20 |
21 | snapshots['test_multiple_files 1'] = FileSnapshot('snap_test_demo/test_multiple_files 1.txt')
22 |
23 | snapshots['test_multiple_files 2'] = FileSnapshot('snap_test_demo/test_multiple_files 2.txt')
24 |
25 | snapshots['test_nested_objects dict'] = {
26 | 'key': GenericRepr('#')
27 | }
28 |
29 | snapshots['test_nested_objects defaultdict'] = {
30 | 'key': [
31 | GenericRepr('#')
32 | ]
33 | }
34 |
35 | snapshots['test_nested_objects list'] = [
36 | GenericRepr('#')
37 | ]
38 |
39 | snapshots['test_nested_objects tuple'] = (
40 | GenericRepr('#')
41 | ,)
42 |
43 | snapshots['test_nested_objects set'] = set([
44 | GenericRepr('#')
45 | ])
46 |
47 | snapshots['test_nested_objects frozenset'] = frozenset([
48 | GenericRepr('#')
49 | ])
50 |
--------------------------------------------------------------------------------
/tests/test_module.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from snapshottest import Snapshot
4 | from snapshottest.module import SnapshotModule
5 |
6 |
class TestSnapshotModuleLoading(object):
    """Loading and saving of ``SnapshotModule`` files on disk."""

    def test_load_not_yet_saved(self, tmpdir):
        # A snapshot file that does not exist yet loads as an empty Snapshot.
        snap_path = tmpdir.join("snap_new.py")
        assert not snap_path.check()
        snap_module = SnapshotModule("tests.snapshots.snap_new", str(snap_path))
        assert isinstance(snap_module.load_snapshots(), Snapshot)

    def test_load_missing_package(self, tmpdir):
        # Import errors raised by the snapshot file itself must propagate.
        snap_path = tmpdir.join("snap_import.py")
        snap_path.write_text("import missing_package\n", "utf-8")
        snap_module = SnapshotModule("tests.snapshots.snap_import", str(snap_path))
        with pytest.raises(ImportError):
            snap_module.load_snapshots()

    def test_load_corrupted_snapshot(self, tmpdir):
        # This test expects a blank snapshot file to be reported as SyntaxError.
        snap_path = tmpdir.join("snap_error.py")
        snap_path.write_text("\n", "utf-8")
        snap_module = SnapshotModule("tests.snapshots.snap_error", str(snap_path))
        with pytest.raises(SyntaxError):
            snap_module.load_snapshots()

    def test_save_and_load_when_test_name_with_quotes(self, tmpdir):
        # Snapshot names containing quotes must survive a save/load round trip.
        snap_path = tmpdir.join("snap_error.py")
        snap_module = SnapshotModule("tests.snapshots.snap_error", str(snap_path))
        snap_module["quo'tes"] = "result"

        snap_module.save()

        assert snap_module.load_snapshots()["quo'tes"] == "result"
38 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from setuptools import setup, find_packages
4 |
# Read the long description explicitly as UTF-8 so building does not depend on
# the platform's default locale encoding.
with open("README.md", encoding="utf-8") as f:
    readme = f.read()

# Dependencies for running the test suite and the bundled examples.
tests_require = ["pytest>=4.6", "pytest-cov", "nose", "django>=1.10.6"]

setup(
    name="snapshottest",
    version="1.0.0a0",
    description="Snapshot testing for pytest, unittest, Django, and Nose",
    long_description=readme,
    long_description_content_type="text/markdown",
    author="Syrus Akbary",
    author_email="me@syrusakbary.com",
    url="https://github.com/syrusakbary/snapshottest",
    entry_points={
        # Entry-point group that pytest scans to auto-load plugins.
        "pytest11": [
            "snapshottest = snapshottest.pytest",
        ],
        "nose.plugins.0.10": ["snapshottest = snapshottest.nose:SnapshotTestPlugin"],
    },
    install_requires=["termcolor", "fastdiff>=0.1.4,<1"],
    tests_require=tests_require,
    extras_require={
        "test": tests_require,
        "pytest": [
            "pytest",
        ],
        "nose": [
            "nose",
        ],
    },
    # ``python_requires`` is the keyword setuptools understands
    # (``requires_python`` is unknown and ignored); pip uses it to refuse
    # installation on unsupported interpreters.
    python_requires=">=3.5",
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Framework :: Django",
        "Framework :: Pytest",
        "Intended Audience :: Developers",
        "Operating System :: OS Independent",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Testing",
        "Topic :: Software Development :: Testing :: Unit",
    ],
    license="MIT",
    packages=find_packages(exclude=("tests",)),
)
51 |
--------------------------------------------------------------------------------
/snapshottest/nose.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | from nose.plugins import Plugin
5 |
6 | from .module import SnapshotModule
7 | from .reporting import reporting_lines
8 | from .unittest import TestCase
9 |
10 | log = logging.getLogger("nose.plugins.snapshottest")
11 |
12 |
class SnapshotTestPlugin(Plugin):
    """Nose plugin that integrates snapshottest with ``nosetests``.

    It adds the ``--snapshot-update`` and ``--snapshot-disable`` options,
    passes the update flag on to snapshottest ``TestCase`` subclasses, saves
    snapshot modules after each context when updating, and prints a summary.
    """

    name = "snapshottest"
    enabled = True

    separator1 = "=" * 70
    separator2 = "-" * 70

    def options(self, parser, env=os.environ):
        """Register the plugin's boolean command-line flags."""
        super(SnapshotTestPlugin, self).options(parser, env=env)
        flag_specs = (
            ("--snapshot-update", "snapshot_update", "Update the snapshots."),
            ("--snapshot-disable", "snapshot_disable", "Disable special SnapshotTest"),
        )
        for flag, destination, help_text in flag_specs:
            parser.add_option(
                flag,
                action="store_true",
                default=False,
                dest=destination,
                help=help_text,
            )

    def configure(self, options, conf):
        """Read parsed options; ``--snapshot-disable`` turns the plugin off."""
        super(SnapshotTestPlugin, self).configure(options, conf)
        self.snapshot_update = options.snapshot_update
        self.enabled = not options.snapshot_disable

    def wantClass(self, cls):
        """Propagate the update flag to snapshottest test case classes.

        Returns ``None`` (no opinion) so nose's own collection decision stands.
        """
        if issubclass(cls, TestCase):
            cls.snapshot_should_update = self.snapshot_update

    def afterContext(self):
        """When updating, prune and save every snapshot module."""
        if not self.snapshot_update:
            return
        for snapshot_module in SnapshotModule.get_modules():
            snapshot_module.delete_unvisited()
            snapshot_module.save()

    def report(self, stream):
        """Write the snapshottest summary to nose's report stream."""
        if not SnapshotModule.has_snapshots():
            return

        for text in (self.separator1, "SnapshotTest summary", self.separator2):
            stream.writeln(text)
        for text in reporting_lines("nosetests"):
            stream.writeln(text)
        stream.writeln(self.separator1)
62 |
--------------------------------------------------------------------------------
/snapshottest/django.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase as dTestCase
2 | from django.test import SimpleTestCase as dSimpleTestCase
3 | from django.test.runner import DiscoverRunner
4 |
5 | from snapshottest.reporting import reporting_lines
6 | from .unittest import TestCase as uTestCase
7 | from .module import SnapshotModule
8 |
9 |
class TestRunnerMixin(object):
    """Mixin that adds snapshottest support to a Django test runner.

    It registers a ``--snapshot-update`` option, prints a snapshot summary
    after the run and, when updating, calls ``delete_unvisited()`` and
    ``save()`` on every snapshot module.
    """

    separator1 = "=" * 70
    separator2 = "-" * 70

    def __init__(self, snapshot_update=False, **kwargs):
        super(TestRunnerMixin, self).__init__(**kwargs)
        # Stored on the shared snapshottest base class so that test case
        # subclasses which do not override it inherit the flag.
        uTestCase.snapshot_should_update = snapshot_update

    @classmethod
    def add_arguments(cls, parser):
        """Extend the runner's argument parser with ``--snapshot-update``."""
        super(TestRunnerMixin, cls).add_arguments(parser)
        parser.add_argument(
            "--snapshot-update",
            default=False,
            action="store_true",
            dest="snapshot_update",
            help="Update the snapshots automatically.",
        )

    def run_tests(self, test_labels, extra_tests=None, **kwargs):
        """Run the suite, print the snapshot report, and save if updating.

        ``extra_tests`` is only forwarded when supplied: Django 5.0 removed
        that parameter from ``DiscoverRunner.run_tests``, so passing it there
        (even as ``None``) raises ``TypeError``.

        :returns: whatever the wrapped runner's ``run_tests`` returns.
        """
        if extra_tests is not None:
            kwargs["extra_tests"] = extra_tests
        result = super(TestRunnerMixin, self).run_tests(
            test_labels=test_labels, **kwargs
        )
        self.print_report()
        # Read the flag from the same class that __init__ writes it to.
        if uTestCase.snapshot_should_update:
            for module in SnapshotModule.get_modules():
                module.delete_unvisited()
                module.save()

        return result

    def print_report(self):
        """Print the snapshottest summary, if there is anything to report."""
        lines = list(reporting_lines("python manage.py test"))
        if lines:
            print("\n" + self.separator1)
            print("SnapshotTest summary")
            print(self.separator2)
            for line in lines:
                print(line)
            print(self.separator1)
50 |
51 |
class TestRunner(TestRunnerMixin, DiscoverRunner):
    """Django's ``DiscoverRunner`` with snapshot reporting and ``--snapshot-update``."""
54 |
55 |
class TestCase(uTestCase, dTestCase):
    """Django ``TestCase`` that also provides ``assertMatchSnapshot``."""
58 |
59 |
class SimpleTestCase(uTestCase, dSimpleTestCase):
    """Django ``SimpleTestCase`` that also provides ``assertMatchSnapshot``."""
62 |
--------------------------------------------------------------------------------
/tests/test_sorted_dict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import enum
3 |
4 | import pytest
5 |
6 | from snapshottest.sorted_dict import SortedDict
7 |
8 |
@pytest.mark.parametrize(
    "key, value",
    [
        ("key1", "value"),
        ("key2", 42),
        ("key3", ["value"]),
        ("key4", [["value"]]),
        ("key5", {"key": "value"}),
        ("key6", [{"key": "value"}]),
        ("key7", {"key": ["value"]}),
        ("key8", [{"key": ["value"]}]),
    ],
)
def test_sorted_dict(key, value):
    """A one-entry dict keeps its (possibly nested) value inside SortedDict."""
    sorted_mapping = SortedDict({key: value})
    assert sorted_mapping[key] == value
25 |
26 |
def test_sorted_dict_string_key():
    """String keys are preserved by SortedDict."""
    key, expected = "key", "value"
    assert SortedDict({key: expected})[key] == expected
31 |
32 |
def test_sorted_dict_int_key():
    """Integer keys are preserved by SortedDict."""
    key, expected = 0, "value"
    assert SortedDict({key: expected})[key] == expected
37 |
38 |
def test_sorted_dict_intenum():
    """IntEnum members work as SortedDict keys."""

    class Fruit(enum.IntEnum):
        APPLE = 1
        ORANGE = 2

    prices = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
    for fruit in (Fruit.APPLE, Fruit.ORANGE):
        assert SortedDict(prices)[fruit] == prices[fruit]
50 |
51 |
def test_sorted_dict_enum():
    """Plain Enum members (not orderable) work as SortedDict keys."""

    class Fruit(enum.Enum):
        APPLE = 1
        ORANGE = 2

    prices = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
    for fruit in (Fruit.APPLE, Fruit.ORANGE):
        assert SortedDict(prices)[fruit] == prices[fruit]
63 |
64 |
def test_sorted_dict_enum_value():
    """An Enum class stored as a value is returned unchanged."""

    class Fruit(enum.Enum):
        APPLE = 1
        ORANGE = 2

    key = "fruit"
    assert SortedDict({key: Fruit})[key] == Fruit
73 |
74 |
def test_sorted_dict_enum_key():
    """An Enum class used as a key can be looked up again."""

    class Fruit(enum.Enum):
        APPLE = 1
        ORANGE = 2

    expected = "fruit"
    assert SortedDict({Fruit: expected})[Fruit] == expected
83 |
--------------------------------------------------------------------------------
/snapshottest/reporting.py:
--------------------------------------------------------------------------------
1 | import os
2 | from termcolor import colored
3 |
4 | from .module import SnapshotModule
5 |
6 |
def reporting_lines(testing_cli):
    """Yield the lines of the SnapshotTest summary.

    :param testing_cli: command shown in the hint for updating snapshots,
        e.g. ``"pytest"`` or ``"python manage.py test"``.

    Only categories with a non-zero count produce a line. Each piece of text
    is formatted before it is coloured and concatenated. Braces inside
    ``testing_cli`` are therefore never interpreted as ``str.format`` fields.
    """
    bold = ["bold"]

    successful_snapshots = SnapshotModule.stats_successful_snapshots()
    if successful_snapshots:
        yield (
            colored("{} snapshots passed".format(successful_snapshots), attrs=bold)
            + "."
        )

    # The remaining stats are indexable as (snapshot count, test suite count).
    new_snapshots = SnapshotModule.stats_new_snapshots()
    if new_snapshots[0]:
        yield (
            colored(
                "{} snapshots written".format(new_snapshots[0]), "green", attrs=bold
            )
            + " in {} test suites.".format(new_snapshots[1])
        )

    inspect_str = colored(
        "Inspect your code or run with `{} --snapshot-update` to update them.".format(
            testing_cli
        ),
        attrs=["dark"],
    )

    failed_snapshots = SnapshotModule.stats_failed_snapshots()
    if failed_snapshots[0]:
        yield (
            colored(
                "{} snapshots failed".format(failed_snapshots[0]), "red", attrs=bold
            )
            + " in {} test suites. ".format(failed_snapshots[1])
            + inspect_str
        )

    unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots()
    if unvisited_snapshots[0]:
        yield (
            colored(
                "{} snapshots deprecated".format(unvisited_snapshots[0]),
                "yellow",
                attrs=bold,
            )
            + " in {} test suites. ".format(unvisited_snapshots[1])
            + inspect_str
        )
39 |
40 |
def diff_report(left, right):
    """Build the failure-message lines for a failed ``left == right`` snapshot check.

    The header names the snapshot and the snapshot module's path relative to
    the current directory. It is followed by ``left.get_diff(right)``.
    """
    bold = ["bold"]
    headline = colored("> ")
    headline += colored("Received value", "red", attrs=bold)
    headline += colored(" does not match ", attrs=bold)
    headline += colored(
        "stored snapshot `{}`".format(left.snapshottest.test_name),
        "green",
        attrs=bold,
    )
    headline += colored(".", attrs=bold)
    location = (
        colored("")
        + "> "
        + os.path.relpath(left.snapshottest.module.filepath, os.getcwd())
    )
    header = [
        "stored snapshot should match the received value",
        "",
        headline,
        location,
        "",
    ]
    return header + left.get_diff(right)
61 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## 1.0.0a0
4 |
5 | ### BREAKING CHANGES
6 |
7 | - Require Python 3.5+
8 |
9 |
10 | ## 0.6.0
11 |
12 | ### New features
13 |
14 | - Sort written snapshots.
15 | - Extract Django TestRunner to a mixin so it can be used with alternative base
16 | classes.
17 |
18 | ### Bug fixes
19 |
20 | - Handle SortedDict with keys other than strings.
21 | - Fix error when snapshotting mock Calls objects.
22 | - Fix formatting of `float("nan")` and `float("inf")`.
23 | - Adopt a valid PEP-508 version range for fastdiff.
24 |
25 | ### Other changes
26 |
27 | - Documentation improvements.
28 | - Add tests and changelog to sdist.
29 |
30 |
31 | ## 0.5.1
32 |
33 | ### New features
34 |
35 | - Add named snapshots #18
36 | - Add support for file snapshots #54
37 | - Hide empty output in the Django test runner #60
38 |
39 | ### Bug fixes
40 |
41 | - Fix snapshot-update with nose #19
- Fix comparisons against objects stored as GenericRepr #20
43 | - Fix setting snapshot_should_update on other TestCases #33
44 | - Fix using non-ASCII characters #31
- Fix silent failure when snapshot files are invalid #45
46 | - Remove unused snapshots in the Django runner #43
47 | - Fix python3 multiline unicode snapshots #46
48 | - Fix checks against falsy snapshots #50
49 | - Various fixes in GenericFormatter and collection formatters #82
50 | - Fix pytest parameterize for multiline strings #87
51 |
52 | ### Other changes
53 |
54 | - Documentation improvements.
55 | - Add wheel distribution #11
- Combine build scripts into a Makefile #83
57 | - Update fastdiff version
58 |
59 |
60 | ## 0.5.0
61 |
* Add Django support. Closes #1
63 | - Add `snapshottest.django.TestRunner`
64 | - Add `snapshottest.django.TestCase`
65 | - Add `--snapshot-update` to django test command. You can use `python manage.py test --snapshot-update`
* Fix #3: all dicts are now sorted before saving and before comparing.
67 |
68 | ### Breaking changes
69 |
* Drop support for Python 3.3, since Django doesn't support that version of Python.
* Since all dicts are sorted, this could be a breaking change for your tests.
  Use the `--snapshot-update` option to update your snapshots.
73 |
--------------------------------------------------------------------------------
/tests/test_formatter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 | from math import isnan
4 |
5 | from snapshottest.formatter import Formatter
6 |
7 | import unittest.mock
8 |
9 |
@pytest.mark.parametrize(
    "text_value, expected",
    [
        # plain strings, the empty string and an escaped backslash
        ("abc", "'abc'"),
        ("", "''"),
        ("back\\slash", "'back\\\\slash'"),
        # single-line strings containing one or both quote characters
        ("""it has "double quotes".""", """'it has "double quotes".'"""),
        ("""it's got single quotes""", '''"it's got single quotes"'''),
        ("""it's got "both quotes".""", """'it\\'s got "both quotes".'"""),
        # strings with newlines are rendered as triple-quoted literals
        ("one\ntwo\n", "'''one\ntwo\n'''"),
        ("three\n'''quotes", '"""three\n\'\'\'quotes"""'),
        ("so many\"\"\"\n'''quotes", "'''so many\"\"\"\n\\'\\'\\'quotes'''"),
    ],
)
def test_text_formatting(text_value, expected):
    """Text is rendered as the expected Python string literal."""
    assert Formatter()(text_value) == expected
31 |
32 |
@pytest.mark.parametrize(
    "text_value, expected",
    [
        ("encodage précis", "'encodage précis'"),
        ("精确的编码", "'精确的编码'"),
        # a backslash must still be escaped in non-ASCII text
        ("omvänt\\snedstreck", "'omvänt\\\\snedstreck'"),
        # multi-line non-ASCII text becomes a triple-quoted literal
        ("ett\ntvå\n", "'''ett\ntvå\n'''"),
    ],
)
def test_non_ascii_text_formatting(text_value, expected):
    """Non-ASCII characters are emitted verbatim, not as escapes."""
    assert Formatter()(text_value) == expected
48 |
49 |
# https://github.com/syrusakbary/snapshottest/issues/115
def test_can_normalize_unittest_mock_call_object():
    """Normalizing a ``unittest.mock.call`` object must not raise."""
    normalizer = Formatter()
    mock_call = unittest.mock.call(1, 2, 3)
    print(normalizer.normalize(mock_call))
54 |
55 |
def test_can_normalize_iterator_objects():
    """Normalizing a generator object must not raise."""
    normalizer = Formatter()
    numbers = (n for n in range(3))
    print(normalizer.normalize(numbers))
59 |
60 |
@pytest.mark.parametrize(
    "value",
    [
        0,
        12.7,
        True,
        False,
        None,
        float("-inf"),
        float("inf"),
    ],
)
def test_basic_formatting_parsing(value):
    """Formatted scalars evaluate back to an equal value of the exact same type."""
    formatter = Formatter()
    formatted = formatter(value)
    # eval is acceptable here: the input is the formatter's own output.
    parsed = eval(formatted)
    assert parsed == value
    # Compare types by identity (E721): ``==`` on types can be overridden by a
    # metaclass, and e.g. True must not come back as the int 1.
    assert type(parsed) is type(value)
78 | assert type(parsed) == type(value)
79 |
80 |
def test_formatting_parsing_nan():
    """NaN is formatted as an expression that evaluates back to NaN."""
    rendered = Formatter()(float("nan"))
    assert isnan(eval(rendered))
88 |
--------------------------------------------------------------------------------
/examples/pytest/test_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from collections import defaultdict
3 |
4 | from snapshottest.file import FileSnapshot
5 |
6 |
def api_client_get(url):
    """Stand-in API client: return a response dict echoing ``url``."""
    response = dict(url=url)
    return response
11 |
12 |
def test_me_endpoint(snapshot):
    """Snapshot the response of the /me endpoint."""
    snapshot.assert_match(api_client_get("/me"))
17 |
18 |
def test_unicode(snapshot):
    """Snapshot a string containing non-ASCII characters."""
    snapshot.assert_match("pépère")
23 |
24 |
class SomeObject(object):
    """Minimal custom object whose repr embeds the repr of its value."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        inner = repr(self.value)
        return "SomeObject({})".format(inner)
31 |
32 |
def test_object(snapshot):
    """
    Snapshot a custom object. It is stored via `snapshottest.GenericRepr`,
    so the snapshot matches only while the object's repr stays the same.
    """
    snapshot.assert_match(SomeObject(3))
41 |
42 |
def test_file(snapshot, tmpdir):
    """
    Snapshot a file. Its contents are copied into a sub-folder of the
    snapshots folder rather than stored as text in the snap_*.py module,
    which suits large files such as media.
    """
    example = tmpdir.join("example.txt")
    example.write("Hello, world!")
    snapshot.assert_match(FileSnapshot(str(example)))
52 |
53 |
def test_multiple_files(snapshot, tmpdir):
    """
    Each file is stored separately, named after its snapshot, inside the
    module's file snapshots folder.
    """
    temp_file1 = tmpdir.join("example1.txt")
    temp_file1.write("Hello, world 1!")
    snapshot.assert_match(FileSnapshot(str(temp_file1)))

    # A second, independent file; it gets its own name and its own snapshot.
    temp_file2 = tmpdir.join("example2.txt")
    temp_file2.write("Hello, world 2!")
    snapshot.assert_match(FileSnapshot(str(temp_file2)))
66 |
67 |
class ObjectWithBadRepr(object):
    """Object whose repr ("#") is not a valid Python expression."""

    def __repr__(self):
        bad_repr = "#"
        return bad_repr
71 |
72 |
def test_nested_objects(snapshot):
    """An object with an unparsable repr is snapshotted inside each container type."""
    obj = ObjectWithBadRepr()
    named_containers = [
        ("dict", {"key": obj}),
        ("defaultdict", defaultdict(list, [("key", [obj])])),
        ("list", [obj]),
        ("tuple", (obj,)),
        ("set", {obj}),
        ("frozenset", frozenset([obj])),
    ]
    for name, container in named_containers:
        snapshot.assert_match(container, name)
89 |
--------------------------------------------------------------------------------
/snapshottest/file.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import filecmp
4 |
5 | from .formatter import Formatter
6 | from .formatters import BaseFormatter
7 |
8 |
class FileSnapshot(object):
    def __init__(self, path):
        """
        Create a file snapshot pointing to the specified `path`. In a snapshot, `path`
        is considered to be relative to the test module's "snapshots" folder. (This is
        done to prevent ugly path manipulations inside the snapshot file.)
        """
        self.path = path

    def __repr__(self):
        return "FileSnapshot({})".format(repr(self.path))

    def __eq__(self, other):
        # Only FileSnapshots are comparable by path. Returning NotImplemented
        # lets Python fall back to its default, so comparing with any other
        # object yields False instead of raising AttributeError on ``other.path``.
        if not isinstance(other, FileSnapshot):
            return NotImplemented
        return self.path == other.path

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash the same
        # field used for equality so equal snapshots hash equally.
        return hash(self.path)
23 |
24 |
class FileSnapshotFormatter(BaseFormatter):
    """Formatter for ``FileSnapshot`` values.

    The snapshot module records only a ``FileSnapshot`` whose path is relative
    to the module's snapshot directory. The file itself is copied next to it.
    """

    def can_format(self, value):
        return isinstance(value, FileSnapshot)

    def store(self, test, value):
        """
        Copy the file from the test location to the snapshot location.

        If the original test file has an extension, the snapshot file will
        use the same extension.

        Returns a ``FileSnapshot`` whose path is relative to the test module's
        snapshot directory.
        """
        file_snapshot_dir = self.get_file_snapshot_dir(test)
        # exist_ok avoids the race between a separate exists() check and
        # creating the directory.
        os.makedirs(file_snapshot_dir, 0o0700, exist_ok=True)
        extension = os.path.splitext(value.path)[1]
        snapshot_file = os.path.join(file_snapshot_dir, test.test_name) + extension
        shutil.copy(value.path, snapshot_file)
        relative_snapshot_filename = os.path.relpath(
            snapshot_file, test.module.snapshot_dir
        )
        return FileSnapshot(relative_snapshot_filename)

    def get_imports(self):
        return (("snapshottest.file", "FileSnapshot"),)

    def format(self, value, indent, formatter):
        return repr(value)

    def assert_value_matches_snapshot(
        self, test, test_value, snapshot_value, formatter
    ):
        """Assert the test file is byte-for-byte identical to the stored file."""
        snapshot_path = os.path.join(test.module.snapshot_dir, snapshot_value.path)
        # shallow=False compares contents, not just os.stat() signatures.
        files_identical = filecmp.cmp(test_value.path, snapshot_path, shallow=False)
        assert files_identical, "Stored file differs from test file"

    @staticmethod
    def get_file_snapshot_dir(test):
        """
        Get the directory for storing file snapshots for `test`.
        Snapshot files are stored under:
            snapshots/snap_<test_module>/
        Right next to where the snapshot module is stored:
            snapshots/snap_<test_module>.py
        """
        return os.path.join(test.module.snapshot_dir, test.module.module)
71 |
72 |
73 | Formatter.register_formatter(FileSnapshotFormatter())
74 |
--------------------------------------------------------------------------------
/tests/test_pytest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from snapshottest.pytest import PyTestSnapshotTest
4 |
5 |
@pytest.fixture
def options():
    """Attributes to override on ``request.config``; none by default."""
    overrides = {}
    return overrides
9 |
10 |
@pytest.fixture
def _apply_options(request, monkeypatch, options):
    """Set every ``options`` entry on ``request.config`` via monkeypatch."""
    for attr_name, attr_value in options.items():
        monkeypatch.setattr(request.config, attr_name, attr_value, raising=False)
15 |
16 |
@pytest.fixture
def pytest_snapshot_test(request, _apply_options):
    """A PyTestSnapshotTest bound to the requesting test."""
    tester = PyTestSnapshotTest(request)
    return tester
20 |
21 |
22 | class TestPyTestSnapShotTest:
23 | def test_property_test_name(self, pytest_snapshot_test):
24 | pytest_snapshot_test.assert_match("counter")
25 | assert (
26 | pytest_snapshot_test.test_name
27 | == "TestPyTestSnapShotTest.test_property_test_name 1"
28 | )
29 |
30 | pytest_snapshot_test.assert_match("named", "named_test")
31 | assert (
32 | pytest_snapshot_test.test_name
33 | == "TestPyTestSnapShotTest.test_property_test_name named_test"
34 | )
35 |
36 | pytest_snapshot_test.assert_match("counter")
37 | assert (
38 | pytest_snapshot_test.test_name
39 | == "TestPyTestSnapShotTest.test_property_test_name 2"
40 | )
41 |
42 |
def test_pytest_snapshottest_property_test_name(pytest_snapshot_test):
    """Unnamed snapshots are numbered; named ones use their name."""
    prefix = "test_pytest_snapshottest_property_test_name"
    steps = [
        (("counter",), " 1"),
        (("named", "named_test"), " named_test"),
        (("counter",), " 2"),
    ]
    for match_args, suffix in steps:
        pytest_snapshot_test.assert_match(*match_args)
        assert pytest_snapshot_test.test_name == prefix + suffix
61 |
62 |
@pytest.mark.parametrize("arg", ["single line string"])
def test_pytest_snapshottest_property_test_name_parametrize_singleline(
    pytest_snapshot_test, arg
):
    """The parametrize id appears in brackets in the snapshot test name."""
    pytest_snapshot_test.assert_match("counter")
    expected_name = (
        "test_pytest_snapshottest_property_test_name_parametrize_singleline"
        "[single line string] 1"
    )
    assert pytest_snapshot_test.test_name == expected_name
73 |
74 |
@pytest.mark.parametrize(
    "arg",
    [
        """
multi
line
string
"""
    ],
)
def test_pytest_snapshottest_property_test_name_parametrize_multiline(
    pytest_snapshot_test, arg
):
    # PyTestSnapshotTest.test_name replaces the two-character sequence "\n"
    # with a space and collapses runs of whitespace, so a multi-line
    # parametrize id becomes "[ multi line string ]" in the snapshot name.
    pytest_snapshot_test.assert_match("counter")
    assert (
        pytest_snapshot_test.test_name
        == "test_pytest_snapshottest_property_test_name_parametrize_multiline"
        "[ multi line string ] 1"
    )
94 |
--------------------------------------------------------------------------------
/snapshottest/pytest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import re
3 |
4 | from .module import SnapshotModule, SnapshotTest
5 | from .diff import PrettyDiff
6 | from .reporting import reporting_lines, diff_report
7 |
8 |
def pytest_addoption(parser):
    """Register snapshottest's command-line options in their own group."""
    options = parser.getgroup("snapshottest")
    options.addoption(
        "--snapshot-update",
        dest="snapshot_update",
        action="store_true",
        default=False,
        help="Update the snapshots.",
    )
    options.addoption(
        "--snapshot-verbose",
        default=False,
        action="store_true",
        help="Dump diagnostic and progress information.",
    )
24 |
25 |
class PyTestSnapshotTest(SnapshotTest):
    """SnapshotTest driven by a pytest ``request`` fixture object."""

    def __init__(self, request=None):
        self.request = request
        super(PyTestSnapshotTest, self).__init__()

    @property
    def module(self):
        test_path = self.request.node.fspath.strpath
        return SnapshotModule.get_module_for_testpath(test_path)

    @property
    def update(self):
        return self.request.config.option.snapshot_update

    @property
    def test_name(self):
        node = self.request.node
        cls_name = getattr(node.cls, "__name__", "")
        prefix = "{}.".format(cls_name) if cls_name else ""
        # The node name may contain the two-character sequence "\n" (note the
        # raw string), e.g. from multi-line parametrize ids. Turn it into a
        # space and squeeze whitespace so the name stays on one line.
        single_line = re.sub(r"\s+", " ", node.name.replace(r"\n", " "))
        return "{}{} {}".format(prefix, single_line, self.curr_snapshot)
50 |
51 |
class SnapshotSession(object):
    """Plugin state for one pytest run: its config and the verbose flag."""

    def __init__(self, config):
        self.config = config
        self.verbose = config.getoption("snapshot_verbose")

    def display(self, tr):
        """Write the SnapshotTest summary section via terminal reporter ``tr``."""
        if SnapshotModule.has_snapshots():
            tr.write_sep("=", "SnapshotTest summary")
            for summary_line in reporting_lines("pytest"):
                tr.write_line(summary_line)
65 |
66 |
def pytest_assertrepr_compare(op, left, right):
    """Explain a failed ``PrettyDiff == value`` assertion with a snapshot diff."""
    if op != "==" or not isinstance(left, PrettyDiff):
        return None
    return diff_report(left, right)
70 |
71 |
@pytest.fixture
def snapshot(request):
    """Yield a PyTestSnapshotTest for the requesting test, entered as a context manager."""
    tester = PyTestSnapshotTest(request)
    with tester as active_tester:
        yield active_tester
76 |
77 |
def pytest_terminal_summary(terminalreporter):
    """Persist snapshots under --snapshot-update, then print the summary."""
    config = terminalreporter.config
    if config.option.snapshot_update:
        for snapshot_module in SnapshotModule.get_modules():
            snapshot_module.delete_unvisited()
            snapshot_module.save()

    config._snapshotsession.display(terminalreporter)
85 |
86 |
# Force the other plugins to initialise first
# (fixes issue with capture not being properly initialised).
# ``pytest.hookimpl(trylast=True)`` replaces the deprecated
# ``@pytest.mark.trylast`` marker on hook implementations.
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
    """Attach a SnapshotSession to ``config`` for the summary hook to use."""
    config._snapshotsession = SnapshotSession(config)
93 |
--------------------------------------------------------------------------------
/examples/django_project/django_project/settings.py:
--------------------------------------------------------------------------------
"""
Django settings for django_project project.

Generated by 'django-admin startproject' using Django 1.11.1.

This project is the example used to exercise snapshottest's Django
integration.

For more information on this file, see
https://docs.djangoproject.com/en/1.11/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.11/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
# (A hard-coded key is only acceptable because this is an example project.)
SECRET_KEY = "$5@im(@s1+p9a&ob#1osrq%*-sue%90o6q*cf0)$h@urtql^4@"

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = []
29 |
30 |
# Application definition

INSTALLED_APPS = [
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    # Local example app; its tests store snapshots in lists/snapshots/.
    "lists",
]

MIDDLEWARE = [
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
]

ROOT_URLCONF = "django_project.urls"

TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
                "django.template.context_processors.debug",
                "django.template.context_processors.request",
                "django.contrib.auth.context_processors.auth",
                "django.contrib.messages.context_processors.messages",
            ],
        },
    },
]

WSGI_APPLICATION = "django_project.wsgi.application"

# snapshottest's runner adds `--snapshot-update` to `python manage.py test`
# and prints the SnapshotTest summary after the run.
TEST_RUNNER = "snapshottest.django.TestRunner"
74 |
# Database
# https://docs.djangoproject.com/en/1.11/ref/settings/#databases

# NOTE(review): no "NAME" is configured. That appears to be enough for the
# test runner, but commands that use the real database (e.g. migrate,
# runserver) presumably need one. Confirm before relying on them.
DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.sqlite3",
    }
}


# Password validation
# https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",  # noqa: E501
    },
    {
        "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
    },
]


# Internationalization
# https://docs.djangoproject.com/en/1.11/topics/i18n/

LANGUAGE_CODE = "en-us"

TIME_ZONE = "UTC"

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.11/howto/static-files/

STATIC_URL = "/static/"
122 |
--------------------------------------------------------------------------------
/snapshottest/unittest.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import inspect
3 |
4 | from .module import SnapshotModule, SnapshotTest
5 | from .diff import PrettyDiff
6 | from .reporting import diff_report
7 |
8 |
class UnitTestSnapshotTest(SnapshotTest):
    """SnapshotTest adapter for a single ``unittest`` test method."""

    def __init__(self, test_class, test_id, test_filepath, should_update, assertEqual):
        self.test_class = test_class
        self.test_id = test_id
        self.test_filepath = test_filepath
        self.assertEqual = assertEqual
        self.should_update = should_update
        super(UnitTestSnapshotTest, self).__init__()

    @property
    def module(self):
        return SnapshotModule.get_module_for_testpath(self.test_filepath)

    @property
    def update(self):
        return self.should_update

    def assert_equals(self, value, snapshot):
        # Delegate to the TestCase's assertEqual so the type-specific equality
        # functions registered with addTypeEqualityFunc are used.
        self.assertEqual(value, snapshot)

    @property
    def test_name(self):
        method_name = self.test_id.rsplit(".", 1)[-1]
        return "{}::{} {}".format(
            self.test_class.__name__, method_name, self.curr_snapshot
        )
34 |
35 |
# Inspired by https://gist.github.com/twolfson/13f5f5784f67fd49b245
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` with snapshot assertions.

    Call ``self.assertMatchSnapshot(value[, name])`` in test methods. The
    class's snapshot module is saved from ``tearDownClass``.
    """

    # Passed to each UnitTestSnapshotTest as ``should_update``. Test runners
    # (e.g. the Django TestRunnerMixin) set it from ``--snapshot-update``.
    snapshot_should_update = False

    @classmethod
    def setUpClass(cls):
        """Initialise per-class snapshot state and hook our setUp/tearDown.

        A subclass ``setUp`` may not call ``super().setUp()``. Such
        subclasses get their ``setUp``/``tearDown`` wrapped so the snapshot
        bookkeeping always runs.
        """
        cls._snapshot_tests = []
        cls._snapshot_file = inspect.getfile(cls)

        if cls is not TestCase and cls.setUp is not TestCase.setUp:
            orig_setUp = cls.setUp
            orig_tearDown = cls.tearDown

            def setUpOverride(self, *args, **kwargs):
                TestCase.setUp(self)
                return orig_setUp(self, *args, **kwargs)

            def tearDownOverride(self, *args, **kwargs):
                TestCase.tearDown(self)
                return orig_tearDown(self, *args, **kwargs)

            cls.setUp = setUpOverride
            cls.tearDown = tearDownOverride

        super(TestCase, cls).setUpClass()

    def comparePrettyDifs(self, obj1, obj2, msg):
        """Type-equality function for PrettyDiff values (registered in setUp).

        Fails with a readable snapshot diff. A custom ``msg`` passed to
        ``assertEqual`` is combined with it the same way built-in unittest
        assertions combine messages.
        """
        if not (obj1 == obj2):
            standard_msg = "\n".join(diff_report(obj1, obj2))
            raise self.failureException(self._formatMessage(msg, standard_msg))

    @classmethod
    def tearDownClass(cls):
        """Save the snapshot module once any test in the class has run setUp."""
        if cls._snapshot_tests:
            module = SnapshotModule.get_module_for_testpath(cls._snapshot_file)
            module.save()
        super(TestCase, cls).tearDownClass()

    def setUp(self):
        """Create this test's snapshot tester and make it the current one."""
        super().setUp()
        self.addTypeEqualityFunc(PrettyDiff, self.comparePrettyDifs)
        self._snapshot = UnitTestSnapshotTest(
            test_class=self.__class__,
            test_id=self.id(),
            test_filepath=self._snapshot_file,
            should_update=self.snapshot_should_update,
            assertEqual=self.assertEqual,
        )
        self._snapshot_tests.append(self._snapshot)
        SnapshotTest._current_tester = self._snapshot

    def tearDown(self):
        """Clear this test's snapshot tester."""
        SnapshotTest._current_tester = None
        self._snapshot = None
        # Mirror setUp: let cooperative tearDown methods later in the MRO run
        # (e.g. when combined with Django's test case classes).
        super().tearDown()

    def assert_match_snapshot(self, value, name=""):
        """Assert ``value`` matches its stored snapshot (optionally named)."""
        self._snapshot.assert_match(value, name=name)

    assertMatchSnapshot = assert_match_snapshot
103 |
--------------------------------------------------------------------------------
/tests/test_snapshot_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from collections import OrderedDict
3 |
4 | from snapshottest.module import SnapshotModule, SnapshotTest
5 |
6 |
class GenericSnapshotTest(SnapshotTest):
    """SnapshotTest subclass not tied to any test framework, for unit tests."""

    def __init__(self, snapshot_module, update=False, current_test_id=None):
        self._generic_options = dict(
            snapshot_module=snapshot_module,
            update=update,
            current_test_id=current_test_id or "test_mocked",
        )
        super(GenericSnapshotTest, self).__init__()

    @property
    def module(self):
        return self._generic_options["snapshot_module"]

    @property
    def update(self):
        return self._generic_options["update"]

    @property
    def test_name(self):
        test_id = self._generic_options["current_test_id"]
        return "{} {}".format(test_id, self.curr_snapshot)

    def reinitialize(self):
        """Reset the base SnapshotTest state, as if a new test run started."""
        super(GenericSnapshotTest, self).__init__()
35 |
36 |
def assert_snapshot_test_ran(snapshot_test, test_name=None):
    """Assert the snapshot ``test_name`` (default: the current one) was visited."""
    name = test_name or snapshot_test.test_name
    visited = snapshot_test.module.visited_snapshots
    assert name in visited
40 |
41 |
def assert_snapshot_test_succeeded(snapshot_test, test_name=None):
    """Assert the snapshot was visited and is not recorded as failed."""
    name = test_name or snapshot_test.test_name
    assert_snapshot_test_ran(snapshot_test, name)
    failed = snapshot_test.module.failed_snapshots
    assert name not in failed
46 |
47 |
def assert_snapshot_test_failed(snapshot_test, test_name=None):
    """Assert the snapshot was visited and is recorded as failed."""
    name = test_name or snapshot_test.test_name
    assert_snapshot_test_ran(snapshot_test, name)
    failed = snapshot_test.module.failed_snapshots
    assert name in failed
52 |
53 |
@pytest.fixture(name="snapshot_test")
def fixture_snapshot_test(tmpdir):
    """A GenericSnapshotTest whose snapshot module lives in ``tmpdir``."""
    snapshot_path = str(tmpdir.join("snap_mocked.py"))
    snapshot_module = SnapshotModule("tests.snapshots.snap_mocked", snapshot_path)
    return GenericSnapshotTest(snapshot_module)
59 |
60 |
# Values that a snapshot must store and match again, including falsy values
# and dict subclasses.
SNAPSHOTABLE_VALUES = [
    "abc",
    b"abc",
    123,
    123.456,
    {"a": 1, "b": 2, "c": 3},  # dict
    ["a", "b", "c"],  # list
    {"a", "b", "c"},  # set
    ("a", "b", "c"),  # tuple
    ("a",),  # single-element tuple
    # Falsy values:
    None,
    False,
    "",
    b"",
    {},  # empty dict
    [],  # empty list
    set(),  # empty set (there is no empty-set literal)
    (),  # empty tuple
    0,
    0.0,
    # dict subclasses -- make sure snapshots don't just coerce to dict:
    OrderedDict([("a", 1), ("b", 2), ("c", 3)]),  # same items as the dict above
    OrderedDict([("c", 3), ("b", 2), ("a", 1)]),  # same items, reversed order
]
87 |
88 |
@pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr)
def test_snapshot_matches_itself(snapshot_test, value):
    """A stored snapshot matches the same value on a later run."""
    for run in range(2):
        if run:
            # The second run compares against the snapshot stored by the first.
            snapshot_test.reinitialize()
        snapshot_test.assert_match(value)
        assert_snapshot_test_succeeded(snapshot_test)
99 |
100 |
@pytest.mark.parametrize(
    "value, other_value",
    [
        pytest.param(
            stored,
            received,
            id="snapshot {!r} shouldn't match {!r}".format(stored, received),
        )
        for stored in SNAPSHOTABLE_VALUES
        for received in SNAPSHOTABLE_VALUES
        if received != stored
    ],
)
def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value):
    """A stored snapshot rejects any value that compares unequal to it."""
    # Run 1 records ``value`` as the snapshot.
    snapshot_test.assert_match(value)
    assert_snapshot_test_succeeded(snapshot_test)

    # Run 2 with a different value must raise and be recorded as failed.
    snapshot_test.reinitialize()
    with pytest.raises(AssertionError):
        snapshot_test.assert_match(other_value)
    assert_snapshot_test_failed(snapshot_test)
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SnapshotTest [![travis][travis-image]][travis-url] [![pypi][pypi-image]][pypi-url]
2 |
3 | [travis-image]: https://img.shields.io/travis/syrusakbary/snapshottest.svg?style=flat
4 | [travis-url]: https://travis-ci.org/syrusakbary/snapshottest
5 | [pypi-image]: https://img.shields.io/pypi/v/snapshottest.svg?style=flat
6 | [pypi-url]: https://pypi.python.org/pypi/snapshottest
7 |
8 |
9 | Snapshot testing is a way to test your APIs without writing actual test cases.
10 |
11 | 1. A snapshot is a single state of your API, saved in a file.
12 | 2. You have a set of snapshots for your API endpoints.
13 | 3. Once you add a new feature, you can *automatically* generate new snapshots for the updated API.
14 |
15 | ## Installation
16 |
17 | $ pip install snapshottest
18 |
19 |
20 | ## Usage with unittest/nose
21 |
22 | ```python
23 | from snapshottest import TestCase
24 |
25 | class APITestCase(TestCase):
26 | def test_api_me(self):
27 | """Testing the API for /me"""
28 | my_api_response = api.client.get('/me')
29 | self.assertMatchSnapshot(my_api_response)
30 |
31 | # Set custom snapshot name: `gpg_response`
32 | my_gpg_response = api.client.get('/me?gpg_key')
33 | self.assertMatchSnapshot(my_gpg_response, 'gpg_response')
34 | ```
35 |
36 | If you want to update the snapshots automatically, run `nosetests --snapshot-update`.
37 |
38 | Check the [Unittest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/unittest).
39 |
40 | ## Usage with pytest
41 |
42 | ```python
43 | def test_mything(snapshot):
44 | """Testing the API for /me"""
45 | my_api_response = api.client.get('/me')
46 | snapshot.assert_match(my_api_response)
47 |
48 | # Set custom snapshot name: `gpg_response`
49 | my_gpg_response = api.client.get('/me?gpg_key')
50 | snapshot.assert_match(my_gpg_response, 'gpg_response')
51 | ```
52 |
53 | If you want to update the snapshots automatically, pass the `--snapshot-update` option to `pytest`.
54 |
55 | Check the [Pytest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/pytest).
56 |
57 | ## Usage with django
58 | Add to your settings:
59 | ```python
60 | TEST_RUNNER = 'snapshottest.django.TestRunner'
61 | ```
62 | To create your snapshot tests:
63 | ```python
64 | from snapshottest.django import TestCase
65 |
66 | class APITestCase(TestCase):
67 | def test_api_me(self):
68 | """Testing the API for /me"""
69 | my_api_response = api.client.get('/me')
70 | self.assertMatchSnapshot(my_api_response)
71 | ```
72 | If you want to update the snapshots automatically, run `python manage.py test --snapshot-update`.
73 | Check the [Django example](https://github.com/syrusakbary/snapshottest/tree/master/examples/django_project).
74 |
75 | ## Disabling terminal colors
76 |
77 | Set the environment variable `ANSI_COLORS_DISABLED` (to any value), e.g.
78 |
79 | ANSI_COLORS_DISABLED=1 pytest
80 |
81 |
82 | # Contributing
83 |
84 | After cloning this repo and configuring a virtualenv for snapshottest (optional, but highly recommended), ensure dependencies are installed by running:
85 |
86 | ```sh
87 | make develop
88 | ```
89 |
90 | After developing, ensure your code is formatted properly by running:
91 |
92 | ```sh
93 | make format-fix
94 | ```
95 |
96 | and then run the full test suite with:
97 |
98 | ```sh
99 | make lint
100 | # and
101 | make test
102 | ```
103 |
104 | To test locally on all supported Python versions, you can use
105 | [tox](https://tox.readthedocs.io/):
106 |
107 | ```sh
108 | pip install tox # (if you haven't before)
109 | tox
110 | ```
111 |
112 | # Notes
113 |
114 | This package is heavily inspired by [jest snapshot testing](https://facebook.github.io/jest/docs/snapshot-testing.html).
115 |
116 | # Reasons to use this package
117 |
118 | > Most of this content is taken from the [Jest snapshot blogpost](https://facebook.github.io/jest/blog/2016/07/27/jest-14.html).
119 |
120 | We want to make it as frictionless as possible to write good tests that are useful.
121 | We observed that when engineers are provided with ready-to-use tools, they end up writing more tests, which in turn results in stable and healthy code bases.
122 |
123 | However, engineers frequently spend more time writing a test than writing the component itself. As a result, many people stopped writing tests altogether, which eventually led to instabilities.
124 |
125 | A typical snapshot test case for a mobile app renders a UI component, takes a screenshot, then compares it to a reference image stored alongside the test. The test will fail if the two images do not match: either the change is unexpected, or the screenshot needs to be updated to the new version of the UI component.
126 |
127 |
128 | ## Snapshot Testing with SnapshotTest
129 |
130 | A similar approach can be taken when it comes to testing your APIs.
131 | Instead of rendering the graphical UI, which would require building the entire app, you can use a test renderer to quickly generate a serializable value for your API response.
132 |
133 |
134 | ## License
135 |
136 | [MIT License](https://github.com/syrusakbary/snapshottest/blob/master/LICENSE)
137 |
138 | [![coveralls][coveralls-image]][coveralls-url]
139 |
140 | [coveralls-image]: https://coveralls.io/repos/syrusakbary/snapshottest/badge.svg?branch=master&service=github
141 | [coveralls-url]: https://coveralls.io/github/syrusakbary/snapshottest?branch=master
142 |
--------------------------------------------------------------------------------
/snapshottest/formatters.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import defaultdict
3 |
4 | from .sorted_dict import SortedDict
5 | from .generic_repr import GenericRepr
6 |
7 |
class BaseFormatter(object):
    """Interface for the per-type formatters used to write and compare snapshots.

    Subclasses decide which values they handle (``can_format``) and how those
    values are rendered as source text (``format``). They may also override the
    hooks that prepare values for storage (``store``) and comparison
    (``normalize`` / ``assert_value_matches_snapshot``).
    """

    def can_format(self, value):
        """Return True when this formatter handles ``value``; must be overridden."""
        raise NotImplementedError()

    def format(self, value, indent, formatter):
        """Return the snapshot-file source text for ``value``; must be overridden."""
        raise NotImplementedError()

    def get_imports(self):
        """Return the imports the formatted output needs; none by default."""
        return ()

    def assert_value_matches_snapshot(
        self, test, test_value, snapshot_value, formatter
    ):
        """Normalize ``test_value`` via ``formatter`` and let ``test`` compare it."""
        normalized = formatter.normalize(test_value)
        test.assert_equals(normalized, snapshot_value)

    def store(self, test, value):
        """Return what should be stored for ``value``: the value itself by default."""
        return value

    def normalize(self, value, formatter):
        """Return ``value`` in comparable form: unchanged by default."""
        return value
28 |
29 |
class TypeFormatter(BaseFormatter):
    """Formatter for instances of the given type(s), rendered by a plain function.

    ``types`` is anything accepted as the second argument of ``isinstance``;
    ``format_func(value, indent, formatter)`` returns the snapshot text.
    """

    def __init__(self, types, format_func):
        self.format_func = format_func
        self.types = types

    def can_format(self, value):
        matches = isinstance(value, self.types)
        return matches

    def format(self, value, indent, formatter):
        render = self.format_func
        return render(value, indent, formatter)
40 |
41 |
class CollectionFormatter(TypeFormatter):
    """TypeFormatter for containers whose items must be normalized recursively."""

    def normalize(self, value, formatter):
        """Return a copy of ``value``, of the same class, with every item normalized.

        Dicts are rebuilt from their normalized ``(key, value)`` pairs.
        """
        iterator = iter(value.items()) if isinstance(value, dict) else iter(value)
        # https://github.com/syrusakbary/snapshottest/issues/115
        # Normally we shouldn't need to turn this into a list, but some iterable
        # constructors need a list not an iterator (e.g. unittest.mock.call).
        items = [formatter.normalize(item) for item in iterator]
        # namedtuple classes take one positional argument per field, so they
        # cannot be rebuilt from a single list. The check is made on the class,
        # not the instance, because objects such as mock.call may answer
        # arbitrary attribute lookups on the instance.
        if isinstance(value, tuple) and hasattr(type(value), "_fields"):
            return value.__class__(*items)
        return value.__class__(items)
49 |
50 |
class DefaultDictFormatter(TypeFormatter):
    """TypeFormatter for ``defaultdict`` that keeps the default factory on normalize."""

    def normalize(self, value, formatter):
        # Normalize each (key, value) pair, then rebuild with the same factory.
        normalized_pairs = [formatter.normalize(pair) for pair in value.items()]
        return defaultdict(value.default_factory, normalized_pairs)
56 |
57 |
def trepr(s):
    """Return a triple-quoted Python literal that evaluates back to ``s``.

    Each line is escaped with ``repr`` (without its surrounding quotes) and the
    lines are joined by real newlines. ``'''`` is used as the delimiter unless
    the text already contains it, in which case ``\"\"\"`` is used. If the text
    contains both, occurrences of ``'''`` are escaped.
    """
    text = "\n".join([repr(line).lstrip("u")[1:-1] for line in s.split("\n")])
    quotes, dquotes = "'''", '"""'
    if quotes in text:
        if dquotes in text:
            text = text.replace(quotes, "\\'\\'\\'")
        else:
            quotes = dquotes

    # A trailing, unescaped quote character would merge with the closing
    # delimiter, causing a SyntaxError or a silently wrong value via implicit
    # string concatenation. Escape it, unless it is already preceded by an odd
    # number of backslashes (i.e. already escaped).
    quote_char = quotes[0]
    if text.endswith(quote_char):
        stem = text[:-1]
        trailing_backslashes = len(stem) - len(stem.rstrip("\\"))
        if trailing_backslashes % 2 == 0:
            text = stem + "\\" + quote_char
    return "%s%s%s" % (quotes, text, quotes)
67 |
68 |
def format_none(value, indent, formatter):
    """Render ``None``; the arguments exist only to match the formatter signature."""
    none_literal = "None"
    return none_literal
71 |
72 |
def format_str(value, indent, formatter):
    """Render a string as a Python literal for the snapshot file.

    Multi-line strings are delegated to ``trepr`` (triple-quoted literal);
    everything else uses ``repr``.
    """
    if "\n" not in value:
        # Snapshot files start with ``from __future__ import unicode_literals``,
        # so a Python 2 ``u'...'`` prefix is redundant and is stripped.
        literal = repr(value)
        return literal.lstrip("u")
    return trepr(value)
81 |
82 |
def format_float(value, indent, formatter):
    """Render a float as source that evaluates back to the same value.

    ``inf`` and ``nan`` have no literal syntax, so they are written as
    ``float("<repr>")``.
    """
    text = repr(value)
    if not (math.isinf(value) or math.isnan(value)):
        return text
    return 'float("%s")' % text
87 |
88 |
def format_std_type(value, indent, formatter):
    """Render a value whose ``repr`` is already valid Python source."""
    return "%r" % (value,)
91 |
92 |
def format_dict(value, indent, formatter):
    """Render a mapping as ``{...}`` with one ``key: value`` entry per line.

    Keys are iterated through ``SortedDict`` (presumably sorted order for a
    stable file; see sorted_dict.py). Values are formatted one indent level
    deeper than the braces.
    """
    ordered = SortedDict(value)
    entry_prefix = formatter.lfchar + formatter.htchar * (indent + 1)
    entries = []
    for key in ordered:
        rendered_key = formatter.format(key, indent)
        rendered_value = formatter.format(ordered[key], indent + 1)
        entries.append(entry_prefix + rendered_key + ": " + rendered_value)
    closing = formatter.lfchar + formatter.htchar * indent
    return "{" + ",".join(entries) + closing + "}"
104 |
105 |
def format_list(value, indent, formatter):
    """Render a list as ``[...]`` with one item per line."""
    body = format_sequence(value, indent, formatter)
    return "[" + body + "]"
108 |
109 |
def format_sequence(value, indent, formatter):
    """Return the comma-separated, one-item-per-line body for list/tuple/set output.

    Each item starts on a new line indented ``indent + 1`` levels. The result
    ends with a newline indented back to ``indent``, so the caller can append
    its closing bracket.
    """
    newline = formatter.lfchar
    item_indent = formatter.htchar * (indent + 1)
    lines = []
    for item in value:
        lines.append(newline + item_indent + formatter.format(item, indent + 1))
    return ",".join(lines) + newline + formatter.htchar * indent
118 |
119 |
def format_tuple(value, indent, formatter):
    """Render a tuple; a one-element tuple gets the trailing comma it requires."""
    closing = ",)" if len(value) == 1 else ")"
    return "(" + format_sequence(value, indent, formatter) + closing
125 |
126 |
def format_set(value, indent, formatter):
    """Render a set as ``set([...])``, one item per line, in a stable order.

    Set iteration order is not guaranteed to match across interpreter runs
    (str/bytes hashes are randomized per process by default). Without sorting,
    items would be reordered in the snapshot file on every rewrite. Items are
    sorted by their rendered text, which also works for mixed item types.
    """
    ordered = sorted(value, key=lambda item: formatter.format(item, indent + 1))
    return "set([%s])" % format_sequence(ordered, indent, formatter)
129 |
130 |
def format_frozenset(value, indent, formatter):
    """Render a frozenset as ``frozenset([...])``, one item per line, in a stable order.

    Set iteration order is not guaranteed to match across interpreter runs
    (str/bytes hashes are randomized per process by default). Items are
    therefore sorted by their rendered text, which also works for mixed item
    types, so the snapshot file does not churn.
    """
    ordered = sorted(value, key=lambda item: formatter.format(item, indent + 1))
    return "frozenset([%s])" % format_sequence(ordered, indent, formatter)
133 |
134 |
class GenericFormatter(BaseFormatter):
    """Catch-all formatter that snapshots any value through ``GenericRepr``."""

    def can_format(self, value):
        # Fallback: accepts every value.
        return True

    def store(self, test, value):
        return GenericRepr.from_value(value)

    def normalize(self, value, formatter):
        return GenericRepr.from_value(value)

    def format(self, value, indent, formatter):
        if not isinstance(value, GenericRepr):
            value = GenericRepr.from_value(value)
        return repr(value)

    def get_imports(self):
        # The formatted output references GenericRepr, so the file must import it.
        return [("snapshottest", "GenericRepr")]

    def assert_value_matches_snapshot(
        self, test, test_value, snapshot_value, formatter
    ):
        """Compare the ``GenericRepr`` text of ``test_value`` with the snapshot.

        Raises:
            AssertionError: if the representations differ. Also raised if the
                stored snapshot is not a ``GenericRepr`` at all, e.g. it was
                recorded from a value of a different type. Previously that case
                raised AttributeError, which callers catching AssertionError
                did not treat as a failed snapshot.
        """
        test_value = GenericRepr.from_value(test_value)
        if not isinstance(snapshot_value, GenericRepr):
            raise AssertionError(
                "Snapshot mismatch: expected %r, got %s"
                % (snapshot_value, test_value.representation)
            )
        # Assert equality between the representations to provide a nice textual diff.
        test.assert_equals(test_value.representation, snapshot_value.representation)
152 |
153 | def assert_value_matches_snapshot(
154 | self, test, test_value, snapshot_value, formatter
155 | ):
156 | test_value = GenericRepr.from_value(test_value)
157 | # Assert equality between the representations to provide a nice textual diff.
158 | test.assert_equals(test_value.representation, snapshot_value.representation)
159 |
160 |
def default_formatters():
    """Return a fresh list of the built-in formatters.

    ``defaultdict`` is listed before ``dict`` and the catch-all
    ``GenericFormatter`` comes last, so order matters (presumably the first
    formatter whose ``can_format`` accepts a value wins; see formatter.py).
    """
    formatters = [TypeFormatter(type(None), format_none)]
    formatters.extend(
        [
            DefaultDictFormatter(defaultdict, format_dict),
            CollectionFormatter(dict, format_dict),
            CollectionFormatter(tuple, format_tuple),
            CollectionFormatter(list, format_list),
            CollectionFormatter(set, format_set),
            CollectionFormatter(frozenset, format_frozenset),
        ]
    )
    formatters.extend(
        [
            TypeFormatter((str,), format_str),
            TypeFormatter((float,), format_float),
            TypeFormatter((int, complex, bool, bytes), format_std_type),
        ]
    )
    formatters.append(GenericFormatter())
    return formatters
175 |
--------------------------------------------------------------------------------
/snapshottest/module.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import errno
3 | import os
4 | import imp
5 | from collections import defaultdict
6 | import logging
7 |
8 | from .snapshot import Snapshot
9 | from .formatter import Formatter
10 | from .error import SnapshotNotFound
11 |
12 |
# Module-level logger (not referenced by any code in this module at present).
logger = logging.getLogger(__name__)
14 |
15 |
16 | def _escape_quotes(text):
17 | return text.replace("'", "\\'")
18 |
19 |
class SnapshotModule(object):
    """In-memory view of one generated snapshot file (``snapshots/snap_<name>.py``).

    Holds the snapshots loaded from disk (``original_snapshot``), a working
    copy that tests read and update (``snapshots``), and sets of visited, new
    and failed snapshot names used for run statistics.
    """

    # Process-wide registry: test file path -> SnapshotModule
    # (filled by get_module_for_testpath).
    _snapshot_modules = {}

    def __init__(self, module, filepath):
        # Both caches are filled lazily; ``None`` means "not loaded yet".
        self._original_snapshot = None
        self._snapshots = None
        self.module = module
        self.filepath = filepath
        # Module name -> set of names imported at the top of the generated file.
        self.imports = defaultdict(set)
        self.visited_snapshots = set()
        self.new_snapshots = set()
        self.failed_snapshots = set()
        self.imports["snapshottest"].add("Snapshot")

    def load_snapshots(self):
        """Import the snapshot file and return its ``snapshots`` object.

        Returns an empty ``Snapshot`` when the file does not exist; any other
        I/O error propagates.
        """
        try:
            source = imp.load_source(self.module, self.filepath)
        # except FileNotFoundError: # Python 3
        except (IOError, OSError) as err:
            if err.errno == errno.ENOENT:
                return Snapshot()
            else:
                raise
        else:
            assert isinstance(source.snapshots, Snapshot)
            return source.snapshots

    def visit(self, snapshot_name):
        """Record that a test touched ``snapshot_name`` during this run."""
        self.visited_snapshots.add(snapshot_name)

    def delete_unvisited(self):
        """Remove every snapshot that no test visited during this run."""
        for unvisited in self.unvisited_snapshots:
            del self.snapshots[unvisited]

    @property
    def unvisited_snapshots(self):
        """Names present in the working snapshots but not visited this run."""
        return set(self.snapshots.keys()) - self.visited_snapshots

    @classmethod
    def total_unvisited_snapshots(cls):
        """Return ``(unvisited snapshot count, modules with any unvisited)``."""
        unvisited_snapshots = 0
        unvisited_modules = 0
        for module in cls.get_modules():
            unvisited_snapshot_len = len(module.unvisited_snapshots)
            unvisited_snapshots += unvisited_snapshot_len
            unvisited_modules += min(unvisited_snapshot_len, 1)

        return unvisited_snapshots, unvisited_modules

    @classmethod
    def get_modules(cls):
        """Return every registered SnapshotModule."""
        return SnapshotModule._snapshot_modules.values()

    @classmethod
    def stats_for_module(cls, getter):
        """Sum ``getter(module)`` over all registered modules.

        Returns ``(total, number of modules where getter(module) >= 1)``.
        """
        count_snapshots = 0
        count_modules = 0
        for module in SnapshotModule._snapshot_modules.values():
            length = getter(module)
            count_snapshots += length
            count_modules += min(length, 1)

        return count_snapshots, count_modules

    @classmethod
    def stats_unvisited_snapshots(cls):
        return cls.stats_for_module(lambda module: len(module.unvisited_snapshots))

    @classmethod
    def stats_visited_snapshots(cls):
        return cls.stats_for_module(lambda module: len(module.visited_snapshots))

    @classmethod
    def stats_new_snapshots(cls):
        return cls.stats_for_module(lambda module: len(module.new_snapshots))

    @classmethod
    def stats_failed_snapshots(cls):
        return cls.stats_for_module(lambda module: len(module.failed_snapshots))

    @classmethod
    def stats_successful_snapshots(cls):
        """Return the number of visited snapshots that did not fail."""
        stats_visited = cls.stats_visited_snapshots()
        stats_failed = cls.stats_failed_snapshots()
        return stats_visited[0] - stats_failed[0]

    @classmethod
    def has_snapshots(cls):
        return cls.stats_visited_snapshots()[0] > 0

    @property
    def original_snapshot(self):
        """Snapshots as loaded from disk, loaded once on first access."""
        # Compare with None, not truthiness: an empty Snapshot is a valid,
        # already-loaded value and must not trigger a reload on every access.
        if self._original_snapshot is None:
            self._original_snapshot = self.load_snapshots()
        return self._original_snapshot

    @property
    def snapshots(self):
        """Working copy of the snapshots, created from the on-disk ones once."""
        # Compare with None: if every snapshot is deleted, the empty working
        # copy must be kept rather than re-copied from original_snapshot.
        # Otherwise the deletions are lost and save() becomes a no-op.
        if self._snapshots is None:
            self._snapshots = Snapshot(self.original_snapshot)
        return self._snapshots

    def __getitem__(self, test_name):
        try:
            return self.snapshots[test_name]
        except KeyError:
            raise SnapshotNotFound(self, test_name)

    def __setitem__(self, key, value):
        if key not in self.snapshots:
            # It's a new test
            self.new_snapshots.add(key)
        self.snapshots[key] = value

    def mark_failed(self, key):
        return self.failed_snapshots.add(key)

    @property
    def snapshot_dir(self):
        return os.path.dirname(self.filepath)

    def save(self):
        """Write the working snapshots to ``filepath`` if they differ from disk.

        Also creates the snapshot directory and an ``__init__.py`` in it when
        they are missing.
        """
        if self.original_snapshot == self.snapshots:
            # If there are no changes, we do nothing
            return

        # Create the snapshot dir in case it doesn't exist
        try:
            os.makedirs(self.snapshot_dir, 0o0700)
        except (IOError, OSError):
            # An existing directory is fine; any other failure means the file
            # cannot be written, so report it here instead of hiding it.
            if not os.path.isdir(self.snapshot_dir):
                raise

        # Create __init__.py in case it doesn't exist
        open(os.path.join(self.snapshot_dir, "__init__.py"), "a").close()

        pretty = Formatter(self.imports)

        with codecs.open(self.filepath, "w", encoding="utf-8") as snapshot_file:
            # ``pretty`` is called before the imports are rendered below,
            # presumably so formatters can register the imports they need.
            snapshots_declarations = [
                """snapshots['{}'] = {}""".format(
                    _escape_quotes(key), pretty(self.snapshots[key])
                )
                for key in sorted(self.snapshots.keys())
            ]

            imports = "\n".join(
                [
                    "from {} import {}".format(
                        module, ", ".join(sorted(module_imports))
                    )
                    for module, module_imports in sorted(self.imports.items())
                ]
            )
            snapshot_file.write(
                """# -*- coding: utf-8 -*-
# snapshottest: v1 - https://goo.gl/zC4yUc
from __future__ import unicode_literals

{}


snapshots = Snapshot()

{}
""".format(
                    imports, "\n\n".join(snapshots_declarations)
                )
            )

    @classmethod
    def get_module_for_testpath(cls, test_filepath):
        """Return (creating and caching it if needed) the module for a test file.

        ``dir/test_x.py`` maps to ``dir/snapshots/snap_test_x.py``.
        """
        if test_filepath not in cls._snapshot_modules:
            dirname = os.path.dirname(test_filepath)
            snapshot_dir = os.path.join(dirname, "snapshots")

            snapshot_basename = "snap_{}.py".format(
                os.path.splitext(os.path.basename(test_filepath))[0]
            )
            snapshot_filename = os.path.join(snapshot_dir, snapshot_basename)
            snapshot_module = "{}".format(os.path.splitext(snapshot_basename)[0])

            cls._snapshot_modules[test_filepath] = SnapshotModule(
                snapshot_module, snapshot_filename
            )

        return cls._snapshot_modules[test_filepath]
206 |
207 |
class SnapshotTest(object):
    """Per-test driver that matches values against a snapshot module.

    Concrete subclasses must supply the ``module`` and ``test_name``
    properties and may override ``update``. Entering an instance as a context
    manager makes it the target of ``assert_match_snapshot``; leaving the
    block saves the module.
    """

    # Tester of the currently open ``with`` block; None outside any block.
    _current_tester = None

    def __init__(self):
        self.snapshot_counter = 1
        self.curr_snapshot = ""

    @property
    def module(self):
        """The snapshot module this test reads from and writes to."""
        raise NotImplementedError("module property needs to be implemented")

    @property
    def update(self):
        """When True, assert_match overwrites snapshots instead of comparing."""
        return False

    @property
    def test_name(self):
        """Key under which the current snapshot is stored."""
        raise NotImplementedError("test_name property needs to be implemented")

    def __enter__(self):
        SnapshotTest._current_tester = self
        return self

    def __exit__(self, type, value, tb):
        # Save even if the block raised; returning None does not suppress it.
        self.save_changes()
        SnapshotTest._current_tester = None

    def visit(self):
        self.module.visit(self.test_name)

    def fail(self):
        self.module.mark_failed(self.test_name)

    def store(self, data):
        """Store ``data``, as prepared by its formatter, under the current name."""
        formatter = Formatter.get_formatter(data)
        self.module[self.test_name] = formatter.store(self, data)

    def assert_value_matches_snapshot(self, test_value, snapshot_value):
        """Delegate the comparison to the formatter chosen for ``test_value``."""
        Formatter.get_formatter(test_value).assert_value_matches_snapshot(
            self, test_value, snapshot_value, Formatter()
        )

    def assert_equals(self, value, snapshot):
        assert value == snapshot

    def assert_match(self, value, name=""):
        """Match ``value`` against its snapshot, recording it when appropriate.

        ``name`` selects a named snapshot. Without it, an auto-incrementing
        counter is used, which advances only when this call completes without
        raising.
        """
        self.curr_snapshot = name or self.snapshot_counter
        self.visit()
        if not self.update:
            self._match_or_record(value)
        else:
            self.store(value)

        if not name:
            self.snapshot_counter += 1

    def _match_or_record(self, value):
        """Compare ``value`` with the stored snapshot, or store it if none exists."""
        try:
            stored = self.module[self.test_name]
        except SnapshotNotFound:
            # First time this snapshot is seen: record it.
            self.store(value)
            return
        try:
            self.assert_value_matches_snapshot(value, stored)
        except AssertionError:
            self.fail()
            raise

    def save_changes(self):
        self.module.save()
277 |
278 |
def assert_match_snapshot(value, name=""):
    """Match ``value`` using the SnapshotTest of the enclosing ``with`` block.

    Raises a plain ``Exception`` when called outside any SnapshotTest context.
    """
    tester = SnapshotTest._current_tester
    if tester:
        tester.assert_match(value, name)
        return
    raise Exception(
        "You need to use assert_match_snapshot in the SnapshotTest context."
    )
286 |
--------------------------------------------------------------------------------