├── .gitignore ├── README.md ├── fastunit ├── __init__.py ├── __main__.py ├── case.py ├── loader.py ├── main.py ├── mock.py ├── result.py ├── runner.py ├── signals.py ├── suite.py └── util.py ├── install.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | dist 3 | *egg* 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fastunit 2 | 3 | Asynchronous unittest for Python 3: test cases run as coroutines. 4 | 5 | ## Usage 6 | 7 | 1. Download or `git clone` the source code, then run `./install.sh` (or `python setup.py install`) to install it. 8 | 9 | 2. Replace `unittest` with `fastunit` in your code. 10 | 11 | ## Known problems 12 | 13 | `HTMLTestRunner` is not supported, because it is designed for test cases that run one after another (sequentially). 14 | -------------------------------------------------------------------------------- /fastunit/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's 3 | Smalltalk testing framework (used with permission). 4 | 5 | This module contains the core framework classes that form the basis of 6 | specific test cases and suites (TestCase, TestSuite etc.), and also a 7 | text-based utility class for running the tests and reporting the results 8 | (TextTestRunner). 
9 | 10 | Simple usage: 11 | 12 | import fastunit 13 | 14 | class IntegerArithmeticTestCase(fastunit.TestCase): 15 | def testAdd(self): # test method names begin with 'test' 16 | self.assertEqual((1 + 2), 3) 17 | self.assertEqual(0 + 1, 1) 18 | def testMultiply(self): 19 | self.assertEqual((0 * 10), 0) 20 | self.assertEqual((5 * 8), 40) 21 | 22 | if __name__ == '__main__': 23 | fastunit.main() 24 | 25 | fastunit mirrors the standard library's unittest API; further information is available from 26 | 27 | http://docs.python.org/library/unittest.html 28 | 29 | Copyright (c) 1999-2003 Steve Purcell 30 | Copyright (c) 2003-2010 Python Software Foundation 31 | This module is free software, and you may redistribute it and/or modify 32 | it under the same terms as Python itself, so long as this copyright message 33 | and disclaimer are retained in their original form. 34 | 35 | IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 36 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF 37 | THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 38 | DAMAGE. 39 | 40 | THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT 41 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 42 | PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, 43 | AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, 44 | SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
45 | """ 46 | 47 | __all__ = ['TestResult', 'TestCase', 'TestSuite', 48 | 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 49 | 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 50 | 'expectedFailure', 'TextTestResult', 'installHandler', 51 | 'registerResult', 'removeResult', 'removeHandler'] 52 | 53 | # Expose obsolete functions for backwards compatibility 54 | __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) 55 | 56 | __unittest = True 57 | 58 | from .result import TestResult 59 | from .case import (TestCase, FunctionTestCase, SkipTest, skip, skipIf, 60 | skipUnless, expectedFailure) 61 | from .suite import BaseTestSuite, TestSuite 62 | from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, 63 | findTestCases) 64 | from .main import TestProgram, main 65 | from .runner import TextTestRunner, TextTestResult 66 | from .signals import installHandler, registerResult, removeResult, removeHandler 67 | 68 | # deprecated 69 | _TextTestResult = TextTestResult 70 | 71 | # There are no tests here, so don't try to run anything discovered from 72 | # introspecting the symbols (e.g. FunctionTestCase). Instead, all our 73 | # tests come from within unittest.test. 74 | def load_tests(loader, tests, pattern): 75 | import os.path 76 | # top level directory cached on loader instance 77 | this_dir = os.path.dirname(__file__) 78 | return loader.discover(start_dir=this_dir, pattern=pattern) 79 | -------------------------------------------------------------------------------- /fastunit/__main__.py: -------------------------------------------------------------------------------- 1 | """Main entry point""" 2 | 3 | import sys 4 | if sys.argv[0].endswith("__main__.py"): 5 | import os.path 6 | # We change sys.argv[0] to make help message more useful 7 | # use executable without path, unquoted 8 | # (it's just a hint anyway) 9 | # (if you have spaces in your executable you get what you deserve!) 
10 | executable = os.path.basename(sys.executable) 11 | sys.argv[0] = executable + " -m unittest" 12 | del os 13 | 14 | __unittest = True 15 | 16 | from .main import main, TestProgram 17 | 18 | main(module=None) 19 | -------------------------------------------------------------------------------- /fastunit/case.py: -------------------------------------------------------------------------------- 1 | """Test case implementation""" 2 | 3 | import sys 4 | import functools 5 | import difflib 6 | import logging 7 | import pprint 8 | import re 9 | import warnings 10 | import collections 11 | import contextlib 12 | import traceback 13 | 14 | from . import result 15 | from .util import (strclass, safe_repr, _count_diff_all_purpose, 16 | _count_diff_hashable, _common_shorten_repr) 17 | 18 | __unittest = True 19 | 20 | _subtest_msg_sentinel = object() 21 | 22 | DIFF_OMITTED = ('\nDiff is %s characters long. ' 23 | 'Set self.maxDiff to None to see it.') 24 | 25 | class SkipTest(Exception): 26 | """ 27 | Raise this exception in a test to skip it. 28 | 29 | Usually you can use TestCase.skipTest() or one of the skipping decorators 30 | instead of raising this directly. 31 | """ 32 | 33 | class _ShouldStop(Exception): 34 | """ 35 | The test should stop. 36 | """ 37 | 38 | class _UnexpectedSuccess(Exception): 39 | """ 40 | The test was supposed to fail, but it didn't! 
41 | """ 42 | 43 | 44 | class _Outcome(object): 45 | def __init__(self, result=None): 46 | self.expecting_failure = False 47 | self.result = result 48 | self.result_supports_subtests = hasattr(result, "addSubTest") 49 | self.success = True 50 | self.skipped = [] 51 | self.expectedFailure = None 52 | self.errors = [] 53 | 54 | @contextlib.contextmanager 55 | def testPartExecutor(self, test_case, isTest=False): 56 | old_success = self.success 57 | self.success = True 58 | try: 59 | yield 60 | except KeyboardInterrupt: 61 | raise 62 | except SkipTest as e: 63 | self.success = False 64 | self.skipped.append((test_case, str(e))) 65 | except _ShouldStop: 66 | pass 67 | except: 68 | exc_info = sys.exc_info() 69 | if self.expecting_failure: 70 | self.expectedFailure = exc_info 71 | else: 72 | self.success = False 73 | self.errors.append((test_case, exc_info)) 74 | # explicitly break a reference cycle: 75 | # exc_info -> frame -> exc_info 76 | exc_info = None 77 | else: 78 | if self.result_supports_subtests and self.success: 79 | self.errors.append((test_case, None)) 80 | finally: 81 | self.success = self.success and old_success 82 | 83 | 84 | def _id(obj): 85 | return obj 86 | 87 | def skip(reason): 88 | """ 89 | Unconditionally skip a test. 90 | """ 91 | def decorator(test_item): 92 | if not isinstance(test_item, type): 93 | @functools.wraps(test_item) 94 | def skip_wrapper(*args, **kwargs): 95 | raise SkipTest(reason) 96 | test_item = skip_wrapper 97 | 98 | test_item.__unittest_skip__ = True 99 | test_item.__unittest_skip_why__ = reason 100 | return test_item 101 | return decorator 102 | 103 | def skipIf(condition, reason): 104 | """ 105 | Skip a test if the condition is true. 106 | """ 107 | if condition: 108 | return skip(reason) 109 | return _id 110 | 111 | def skipUnless(condition, reason): 112 | """ 113 | Skip a test unless the condition is true. 
114 | """ 115 | if not condition: 116 | return skip(reason) 117 | return _id 118 | 119 | def expectedFailure(test_item): 120 | test_item.__unittest_expecting_failure__ = True 121 | return test_item 122 | 123 | def _is_subtype(expected, basetype): 124 | if isinstance(expected, tuple): 125 | return all(_is_subtype(e, basetype) for e in expected) 126 | return isinstance(expected, type) and issubclass(expected, basetype) 127 | 128 | class _BaseTestCaseContext: 129 | 130 | def __init__(self, test_case): 131 | self.test_case = test_case 132 | 133 | def _raiseFailure(self, standardMsg): 134 | msg = self.test_case._formatMessage(self.msg, standardMsg) 135 | raise self.test_case.failureException(msg) 136 | 137 | class _AssertRaisesBaseContext(_BaseTestCaseContext): 138 | 139 | def __init__(self, expected, test_case, expected_regex=None): 140 | _BaseTestCaseContext.__init__(self, test_case) 141 | self.expected = expected 142 | self.test_case = test_case 143 | if expected_regex is not None: 144 | expected_regex = re.compile(expected_regex) 145 | self.expected_regex = expected_regex 146 | self.obj_name = None 147 | self.msg = None 148 | 149 | def handle(self, name, args, kwargs): 150 | """ 151 | If args is empty, assertRaises/Warns is being used as a 152 | context manager, so check for a 'msg' kwarg and return self. 153 | If args is not empty, call a callable passing positional and keyword 154 | arguments. 
155 | """ 156 | try: 157 | if not _is_subtype(self.expected, self._base_type): 158 | raise TypeError('%s() arg 1 must be %s' % 159 | (name, self._base_type_str)) 160 | if args and args[0] is None: 161 | warnings.warn("callable is None", 162 | DeprecationWarning, 3) 163 | args = () 164 | if not args: 165 | self.msg = kwargs.pop('msg', None) 166 | if kwargs: 167 | warnings.warn('%r is an invalid keyword argument for ' 168 | 'this function' % next(iter(kwargs)), 169 | DeprecationWarning, 3) 170 | return self 171 | 172 | callable_obj, *args = args 173 | try: 174 | self.obj_name = callable_obj.__name__ 175 | except AttributeError: 176 | self.obj_name = str(callable_obj) 177 | with self: 178 | callable_obj(*args, **kwargs) 179 | finally: 180 | # bpo-23890: manually break a reference cycle 181 | self = None 182 | 183 | 184 | class _AssertRaisesContext(_AssertRaisesBaseContext): 185 | """A context manager used to implement TestCase.assertRaises* methods.""" 186 | 187 | _base_type = BaseException 188 | _base_type_str = 'an exception type or tuple of exception types' 189 | 190 | def __enter__(self): 191 | return self 192 | 193 | def __exit__(self, exc_type, exc_value, tb): 194 | if exc_type is None: 195 | try: 196 | exc_name = self.expected.__name__ 197 | except AttributeError: 198 | exc_name = str(self.expected) 199 | if self.obj_name: 200 | self._raiseFailure("{} not raised by {}".format(exc_name, 201 | self.obj_name)) 202 | else: 203 | self._raiseFailure("{} not raised".format(exc_name)) 204 | else: 205 | traceback.clear_frames(tb) 206 | if not issubclass(exc_type, self.expected): 207 | # let unexpected exceptions pass through 208 | return False 209 | # store exception, without traceback, for later retrieval 210 | self.exception = exc_value.with_traceback(None) 211 | if self.expected_regex is None: 212 | return True 213 | 214 | expected_regex = self.expected_regex 215 | if not expected_regex.search(str(exc_value)): 216 | self._raiseFailure('"{}" does not match 
"{}"'.format( 217 | expected_regex.pattern, str(exc_value))) 218 | return True 219 | 220 | 221 | class _AssertWarnsContext(_AssertRaisesBaseContext): 222 | """A context manager used to implement TestCase.assertWarns* methods.""" 223 | 224 | _base_type = Warning 225 | _base_type_str = 'a warning type or tuple of warning types' 226 | 227 | def __enter__(self): 228 | # The __warningregistry__'s need to be in a pristine state for tests 229 | # to work properly. 230 | for v in sys.modules.values(): 231 | if getattr(v, '__warningregistry__', None): 232 | v.__warningregistry__ = {} 233 | self.warnings_manager = warnings.catch_warnings(record=True) 234 | self.warnings = self.warnings_manager.__enter__() 235 | warnings.simplefilter("always", self.expected) 236 | return self 237 | 238 | def __exit__(self, exc_type, exc_value, tb): 239 | self.warnings_manager.__exit__(exc_type, exc_value, tb) 240 | if exc_type is not None: 241 | # let unexpected exceptions pass through 242 | return 243 | try: 244 | exc_name = self.expected.__name__ 245 | except AttributeError: 246 | exc_name = str(self.expected) 247 | first_matching = None 248 | for m in self.warnings: 249 | w = m.message 250 | if not isinstance(w, self.expected): 251 | continue 252 | if first_matching is None: 253 | first_matching = w 254 | if (self.expected_regex is not None and 255 | not self.expected_regex.search(str(w))): 256 | continue 257 | # store warning for later retrieval 258 | self.warning = w 259 | self.filename = m.filename 260 | self.lineno = m.lineno 261 | return 262 | # Now we simply try to choose a helpful failure message 263 | if first_matching is not None: 264 | self._raiseFailure('"{}" does not match "{}"'.format( 265 | self.expected_regex.pattern, str(first_matching))) 266 | if self.obj_name: 267 | self._raiseFailure("{} not triggered by {}".format(exc_name, 268 | self.obj_name)) 269 | else: 270 | self._raiseFailure("{} not triggered".format(exc_name)) 271 | 272 | 273 | 274 | _LoggingWatcher = 
collections.namedtuple("_LoggingWatcher", 275 | ["records", "output"]) 276 | 277 | 278 | class _CapturingHandler(logging.Handler): 279 | """ 280 | A logging handler capturing all (raw and formatted) logging output. 281 | """ 282 | 283 | def __init__(self): 284 | logging.Handler.__init__(self) 285 | self.watcher = _LoggingWatcher([], []) 286 | 287 | def flush(self): 288 | pass 289 | 290 | def emit(self, record): 291 | self.watcher.records.append(record) 292 | msg = self.format(record) 293 | self.watcher.output.append(msg) 294 | 295 | 296 | 297 | class _AssertLogsContext(_BaseTestCaseContext): 298 | """A context manager used to implement TestCase.assertLogs().""" 299 | 300 | LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" 301 | 302 | def __init__(self, test_case, logger_name, level): 303 | _BaseTestCaseContext.__init__(self, test_case) 304 | self.logger_name = logger_name 305 | if level: 306 | self.level = logging._nameToLevel.get(level, level) 307 | else: 308 | self.level = logging.INFO 309 | self.msg = None 310 | 311 | def __enter__(self): 312 | if isinstance(self.logger_name, logging.Logger): 313 | logger = self.logger = self.logger_name 314 | else: 315 | logger = self.logger = logging.getLogger(self.logger_name) 316 | formatter = logging.Formatter(self.LOGGING_FORMAT) 317 | handler = _CapturingHandler() 318 | handler.setFormatter(formatter) 319 | self.watcher = handler.watcher 320 | self.old_handlers = logger.handlers[:] 321 | self.old_level = logger.level 322 | self.old_propagate = logger.propagate 323 | logger.handlers = [handler] 324 | logger.setLevel(self.level) 325 | logger.propagate = False 326 | return handler.watcher 327 | 328 | def __exit__(self, exc_type, exc_value, tb): 329 | self.logger.handlers = self.old_handlers 330 | self.logger.propagate = self.old_propagate 331 | self.logger.setLevel(self.old_level) 332 | if exc_type is not None: 333 | # let unexpected exceptions pass through 334 | return False 335 | if len(self.watcher.records) == 0: 336 
| self._raiseFailure( 337 | "no logs of level {} or higher triggered on {}" 338 | .format(logging.getLevelName(self.level), self.logger.name)) 339 | 340 | 341 | class TestCase(object): 342 | """A class whose instances are single test cases. 343 | 344 | By default, the test code itself should be placed in a method named 345 | 'runTest'. 346 | 347 | If the fixture may be used for many test cases, create as 348 | many test methods as are needed. When instantiating such a TestCase 349 | subclass, specify in the constructor arguments the name of the test method 350 | that the instance is to execute. 351 | 352 | Test authors should subclass TestCase for their own tests. Construction 353 | and deconstruction of the test's environment ('fixture') can be 354 | implemented by overriding the 'setUp' and 'tearDown' methods respectively. 355 | 356 | If it is necessary to override the __init__ method, the base class 357 | __init__ method must always be called. It is important that subclasses 358 | should not change the signature of their __init__ method, since instances 359 | of the classes are instantiated automatically by parts of the framework 360 | in order to be run. 361 | 362 | When subclassing TestCase, you can set these attributes: 363 | * failureException: determines which exception will be raised when 364 | the instance's assertion methods fail; test methods raising this 365 | exception will be deemed to have 'failed' rather than 'errored'. 366 | * longMessage: determines whether long messages (including repr of 367 | objects used in assert methods) will be printed on failure in *addition* 368 | to any explicit message passed. 369 | * maxDiff: sets the maximum length of a diff in failure messages 370 | by assert methods using difflib. It is looked up as an instance 371 | attribute so can be configured by individual tests if required. 
372 | """ 373 | 374 | failureException = AssertionError 375 | 376 | longMessage = True 377 | 378 | maxDiff = 80*8 379 | 380 | # If a string is longer than _diffThreshold, use normal comparison instead 381 | # of difflib. See #11763. 382 | _diffThreshold = 2**16 383 | 384 | # Attribute used by TestSuite for classSetUp 385 | 386 | _classSetupFailed = False 387 | 388 | def __init__(self, methodName='runTest'): 389 | """Create an instance of the class that will use the named test 390 | method when executed. Raises a ValueError if the instance does 391 | not have a method with the specified name. 392 | """ 393 | self._testMethodName = methodName 394 | self._outcome = None 395 | self._testMethodDoc = 'No test' 396 | try: 397 | testMethod = getattr(self, methodName) 398 | except AttributeError: 399 | if methodName != 'runTest': 400 | # we allow instantiation with no explicit method name 401 | # but not an *incorrect* or missing method name 402 | raise ValueError("no such test method in %s: %s" % 403 | (self.__class__, methodName)) 404 | else: 405 | self._testMethodDoc = testMethod.__doc__ 406 | self._cleanups = [] 407 | self._subtest = None 408 | 409 | # Map types to custom assertEqual functions that will compare 410 | # instances of said type in more detail to generate a more useful 411 | # error message. 412 | self._type_equality_funcs = {} 413 | self.addTypeEqualityFunc(dict, 'assertDictEqual') 414 | self.addTypeEqualityFunc(list, 'assertListEqual') 415 | self.addTypeEqualityFunc(tuple, 'assertTupleEqual') 416 | self.addTypeEqualityFunc(set, 'assertSetEqual') 417 | self.addTypeEqualityFunc(frozenset, 'assertSetEqual') 418 | self.addTypeEqualityFunc(str, 'assertMultiLineEqual') 419 | 420 | def addTypeEqualityFunc(self, typeobj, function): 421 | """Add a type specific assertEqual style function to compare a type. 422 | 423 | This method is for use by TestCase subclasses that need to register 424 | their own type equality functions to provide nicer error messages. 
425 | 426 | Args: 427 | typeobj: The data type to call this function on when both values 428 | are of the same type in assertEqual(). 429 | function: The callable taking two arguments and an optional 430 | msg= argument that raises self.failureException with a 431 | useful error message when the two arguments are not equal. 432 | """ 433 | self._type_equality_funcs[typeobj] = function 434 | 435 | def addCleanup(self, function, *args, **kwargs): 436 | """Add a function, with arguments, to be called when the test is 437 | completed. Functions added are called on a LIFO basis and are 438 | called after tearDown on test failure or success. 439 | 440 | Cleanup items are called even if setUp fails (unlike tearDown).""" 441 | self._cleanups.append((function, args, kwargs)) 442 | 443 | def setUp(self): 444 | "Hook method for setting up the test fixture before exercising it." 445 | pass 446 | 447 | def tearDown(self): 448 | "Hook method for deconstructing the test fixture after testing it." 449 | pass 450 | 451 | @classmethod 452 | def setUpClass(cls): 453 | "Hook method for setting up class fixture before running tests in the class." 454 | 455 | @classmethod 456 | def tearDownClass(cls): 457 | "Hook method for deconstructing the class fixture after running all tests in the class." 458 | 459 | def countTestCases(self): 460 | return 1 461 | 462 | def defaultTestResult(self): 463 | return result.TestResult() 464 | 465 | def shortDescription(self): 466 | """Returns a one-line description of the test, or None if no 467 | description has been provided. 468 | 469 | The default implementation of this method returns the first line of 470 | the specified test method's docstring. 
471 | """ 472 | doc = self._testMethodDoc 473 | return doc and doc.split("\n")[0].strip() or None 474 | 475 | 476 | def id(self): 477 | return "%s.%s" % (strclass(self.__class__), self._testMethodName) 478 | 479 | def __eq__(self, other): 480 | if type(self) is not type(other): 481 | return NotImplemented 482 | 483 | return self._testMethodName == other._testMethodName 484 | 485 | def __hash__(self): 486 | return hash((type(self), self._testMethodName)) 487 | 488 | def __str__(self): 489 | return "%s (%s)" % (self._testMethodName, strclass(self.__class__)) 490 | 491 | def __repr__(self): 492 | return "<%s testMethod=%s>" % \ 493 | (strclass(self.__class__), self._testMethodName) 494 | 495 | def _addSkip(self, result, test_case, reason): 496 | addSkip = getattr(result, 'addSkip', None) 497 | if addSkip is not None: 498 | addSkip(test_case, reason) 499 | else: 500 | warnings.warn("TestResult has no addSkip method, skips not reported", 501 | RuntimeWarning, 2) 502 | result.addSuccess(test_case) 503 | 504 | @contextlib.contextmanager 505 | def subTest(self, msg=_subtest_msg_sentinel, **params): 506 | """Return a context manager that will return the enclosed block 507 | of code in a subtest identified by the optional message and 508 | keyword parameters. A failure in the subtest marks the test 509 | case as failed but resumes execution at the end of the enclosed 510 | block, allowing further test code to be executed. 
511 | """ 512 | if not self._outcome.result_supports_subtests: 513 | yield 514 | return 515 | parent = self._subtest 516 | if parent is None: 517 | params_map = collections.ChainMap(params) 518 | else: 519 | params_map = parent.params.new_child(params) 520 | self._subtest = _SubTest(self, msg, params_map) 521 | try: 522 | with self._outcome.testPartExecutor(self._subtest, isTest=True): 523 | yield 524 | if not self._outcome.success: 525 | result = self._outcome.result 526 | if result is not None and result.failfast: 527 | raise _ShouldStop 528 | elif self._outcome.expectedFailure: 529 | # If the test is expecting a failure, we really want to 530 | # stop now and register the expected failure. 531 | raise _ShouldStop 532 | finally: 533 | self._subtest = parent 534 | 535 | def _feedErrorsToResult(self, result, errors): 536 | for test, exc_info in errors: 537 | if isinstance(test, _SubTest): 538 | result.addSubTest(test.test_case, test, exc_info) 539 | elif exc_info is not None: 540 | if issubclass(exc_info[0], self.failureException): 541 | result.addFailure(test, exc_info) 542 | else: 543 | result.addError(test, exc_info) 544 | 545 | def _addExpectedFailure(self, result, exc_info): 546 | try: 547 | addExpectedFailure = result.addExpectedFailure 548 | except AttributeError: 549 | warnings.warn("TestResult has no addExpectedFailure method, reporting as passes", 550 | RuntimeWarning) 551 | result.addSuccess(self) 552 | else: 553 | addExpectedFailure(self, exc_info) 554 | 555 | def _addUnexpectedSuccess(self, result): 556 | try: 557 | addUnexpectedSuccess = result.addUnexpectedSuccess 558 | except AttributeError: 559 | warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure", 560 | RuntimeWarning) 561 | # We need to pass an actual exception and traceback to addFailure, 562 | # otherwise the legacy result can choke. 
563 | try: 564 | raise _UnexpectedSuccess from None 565 | except _UnexpectedSuccess: 566 | result.addFailure(self, sys.exc_info()) 567 | else: 568 | addUnexpectedSuccess(self) 569 | 570 | def run(self, result=None): 571 | orig_result = result 572 | if result is None: 573 | result = self.defaultTestResult() 574 | startTestRun = getattr(result, 'startTestRun', None) 575 | if startTestRun is not None: 576 | startTestRun() 577 | 578 | result.startTest(self) 579 | 580 | testMethod = getattr(self, self._testMethodName) 581 | if (getattr(self.__class__, "__unittest_skip__", False) or 582 | getattr(testMethod, "__unittest_skip__", False)): 583 | # If the class or method was skipped. 584 | try: 585 | skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') 586 | or getattr(testMethod, '__unittest_skip_why__', '')) 587 | self._addSkip(result, self, skip_why) 588 | finally: 589 | result.stopTest(self) 590 | return 591 | expecting_failure_method = getattr(testMethod, 592 | "__unittest_expecting_failure__", False) 593 | expecting_failure_class = getattr(self, 594 | "__unittest_expecting_failure__", False) 595 | expecting_failure = expecting_failure_class or expecting_failure_method 596 | outcome = _Outcome(result) 597 | try: 598 | self._outcome = outcome 599 | 600 | with outcome.testPartExecutor(self): 601 | self.setUp() 602 | if outcome.success: 603 | outcome.expecting_failure = expecting_failure 604 | with outcome.testPartExecutor(self, isTest=True): 605 | testMethod() 606 | outcome.expecting_failure = False 607 | with outcome.testPartExecutor(self): 608 | self.tearDown() 609 | 610 | self.doCleanups() 611 | for test, reason in outcome.skipped: 612 | self._addSkip(result, test, reason) 613 | self._feedErrorsToResult(result, outcome.errors) 614 | if outcome.success: 615 | if expecting_failure: 616 | if outcome.expectedFailure: 617 | self._addExpectedFailure(result, outcome.expectedFailure) 618 | else: 619 | self._addUnexpectedSuccess(result) 620 | else: 621 | 
result.addSuccess(self) 622 | return result 623 | finally: 624 | result.stopTest(self) 625 | if orig_result is None: 626 | stopTestRun = getattr(result, 'stopTestRun', None) 627 | if stopTestRun is not None: 628 | stopTestRun() 629 | 630 | # explicitly break reference cycles: 631 | # outcome.errors -> frame -> outcome -> outcome.errors 632 | # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure 633 | outcome.errors.clear() 634 | outcome.expectedFailure = None 635 | 636 | # clear the outcome, no more needed 637 | self._outcome = None 638 | 639 | def doCleanups(self): 640 | """Execute all cleanup functions. Normally called for you after 641 | tearDown.""" 642 | outcome = self._outcome or _Outcome() 643 | while self._cleanups: 644 | function, args, kwargs = self._cleanups.pop() 645 | with outcome.testPartExecutor(self): 646 | function(*args, **kwargs) 647 | 648 | # return this for backwards compatibility 649 | # even though we no longer us it internally 650 | return outcome.success 651 | 652 | def __call__(self, *args, **kwds): 653 | return self.run(*args, **kwds) 654 | 655 | def debug(self): 656 | """Run the test without collecting errors in a TestResult""" 657 | self.setUp() 658 | getattr(self, self._testMethodName)() 659 | self.tearDown() 660 | while self._cleanups: 661 | function, args, kwargs = self._cleanups.pop(-1) 662 | function(*args, **kwargs) 663 | 664 | def skipTest(self, reason): 665 | """Skip this test.""" 666 | raise SkipTest(reason) 667 | 668 | def fail(self, msg=None): 669 | """Fail immediately, with the given message.""" 670 | raise self.failureException(msg) 671 | 672 | def assertFalse(self, expr, msg=None): 673 | """Check that the expression is false.""" 674 | if expr: 675 | msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr)) 676 | raise self.failureException(msg) 677 | 678 | def assertTrue(self, expr, msg=None): 679 | """Check that the expression is true.""" 680 | if not expr: 681 | msg = self._formatMessage(msg, 
# (continuation of TestCase.assertTrue, which begins before this chunk)
                                  "%s is not true" % safe_repr(expr))
            raise self.failureException(msg)

    def _formatMessage(self, msg, standardMsg):
        """Honour the longMessage attribute when generating failure messages.

        If longMessage is False this means:
        * Use only an explicit message if it is provided
        * Otherwise use the standard message for the assert

        If longMessage is True:
        * Use the standard message
        * If an explicit message is provided, append ' : ' and the
          explicit message

        Args:
            msg: The caller-supplied message, or None.
            standardMsg: The assertion's own description of the failure.

        Returns:
            The text to pass to self.failureException.
        """
        if not self.longMessage:
            return msg or standardMsg
        if msg is None:
            return standardMsg
        try:
            # %-formatting is carried over from the Python 2 implementation,
            # where switching to '{}' formatting changed how unicode input
            # was handled.
            return '%s : %s' % (standardMsg, msg)
        except UnicodeDecodeError:
            # Fall back to safe_repr() of both parts if formatting them
            # directly raises UnicodeDecodeError.
            return '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))

    def assertRaises(self, expected_exception, *args, **kwargs):
        """Fail unless an exception of class expected_exception is raised
           by the callable when invoked with specified positional and
           keyword arguments. If a different type of exception is
           raised, it will not be caught, and the test case will be
           deemed to have suffered an error, exactly as for an
           unexpected exception.

           If called with the callable and arguments omitted, will return a
           context object used like this::

                with self.assertRaises(SomeException):
                    do_something()

           An optional keyword argument 'msg' can be provided when assertRaises
           is used as a context object.

           The context manager keeps a reference to the exception as
           the 'exception' attribute. This allows you to inspect the
           exception after the assertion::

               with self.assertRaises(SomeException) as cm:
                   do_something()
               the_exception = cm.exception
               self.assertEqual(the_exception.error_code, 3)
        """
        context = _AssertRaisesContext(expected_exception, self)
        try:
            return context.handle('assertRaises', args, kwargs)
        finally:
            # bpo-23890: manually break a reference cycle by dropping this
            # frame's reference to the context object.
            context = None

    def assertWarns(self, expected_warning, *args, **kwargs):
        """Fail unless a warning of class expected_warning is triggered
           by the callable when invoked with specified positional and
           keyword arguments.  If a different type of warning is
           triggered, it will not be handled: depending on the other
           warning filtering rules in effect, it might be silenced, printed
           out, or raised as an exception.

           If called with the callable and arguments omitted, will return a
           context object used like this::

                with self.assertWarns(SomeWarning):
                    do_something()

           An optional keyword argument 'msg' can be provided when assertWarns
           is used as a context object.

           The context manager keeps a reference to the first matching
           warning as the 'warning' attribute; similarly, the 'filename'
           and 'lineno' attributes give you information about the line
           of Python code from which the warning was triggered.
           This allows you to inspect the warning after the assertion::

               with self.assertWarns(SomeWarning) as cm:
                   do_something()
               the_warning = cm.warning
               self.assertEqual(the_warning.some_attribute, 147)
        """
        context = _AssertWarnsContext(expected_warning, self)
        return context.handle('assertWarns', args, kwargs)

    def assertLogs(self, logger=None, level=None):
        """Fail unless a log message of level *level* or higher is emitted
        on *logger* or its children.  If omitted, *level* defaults to
        INFO and *logger* defaults to the root logger.

        This method must be used as a context manager, and will yield
        a recording object with two attributes: `output` and `records`.
        At the end of the context manager, the `output` attribute will
        be a list of the matching formatted log messages and the
        `records` attribute will be a list of the corresponding LogRecord
        objects.

        Example::

            with self.assertLogs('foo', level='INFO') as cm:
                logging.getLogger('foo').info('first message')
                logging.getLogger('foo.bar').error('second message')
            self.assertEqual(cm.output, ['INFO:foo:first message',
                                         'ERROR:foo.bar:second message'])
        """
        return _AssertLogsContext(self, logger, level)

    def _getAssertEqualityFunc(self, first, second):
        """Get a detailed comparison function for the types of the two args.

        Returns: A callable accepting (first, second, msg=None) that will
        raise a failure exception if first != second with a useful human
        readable error message for those types.
        """
        #
        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
        # and vice versa.  I opted for the conservative approach in case
        # subclasses are not intended to be compared in detail to their super
        # class instances using a type equality func.  This means testing
        # subtypes won't automagically use the detailed comparison.  Callers
        # should use their type specific assertSpamEqual method to compare
        # subclasses if the detailed comparison is desired and appropriate.
        # See the discussion in http://bugs.python.org/issue2578.
        #
        if type(first) is type(second):
            asserter = self._type_equality_funcs.get(type(first))
            if asserter is not None:
                # Registered funcs may be stored by method name; resolve the
                # name to the bound method on this instance.
                if isinstance(asserter, str):
                    asserter = getattr(self, asserter)
                return asserter

        return self._baseAssertEqual

    def _baseAssertEqual(self, first, second, msg=None):
        """The default assertEqual implementation, not type specific."""
        if not first == second:
            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
            msg = self._formatMessage(msg, standardMsg)
            raise self.failureException(msg)

    def assertEqual(self, first, second, msg=None):
        """Fail if the two objects are unequal as determined by the '=='
           operator.

           When both arguments have exactly the same type and a
           type-specific comparison is registered, that comparison is used
           so the failure message is more detailed.
        """
        assertion_func = self._getAssertEqualityFunc(first, second)
        assertion_func(first, second, msg=msg)

    def assertNotEqual(self, first, second, msg=None):
        """Fail if the two objects are equal as determined by the '!='
           operator.
        """
        if not first != second:
            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
                                                          safe_repr(second)))
            raise self.failureException(msg)

    def assertAlmostEqual(self, first, second, places=None, msg=None,
                          delta=None):
        """Fail if the two objects are unequal as determined by their
           difference rounded to the given number of decimal places
           (default 7) and comparing to zero, or by comparing that the
           difference between the two objects is more than the given
           delta.

           Note that decimal places (from zero) are usually not the same
           as significant digits (measured from the most significant digit).

           If the two objects compare equal then they will automatically
           compare almost equal.

           Raises:
               TypeError: if both delta and places are given (checked only
                   after the equality shortcut, so equal objects still pass).
        """
        if first == second:
            # shortcut
            return
        if delta is not None and places is not None:
            raise TypeError("specify delta or places not both")

        if delta is not None:
            if abs(first - second) <= delta:
                return

            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
                                                        safe_repr(second),
                                                        safe_repr(delta))
        else:
            if places is None:
                places = 7

            if round(abs(second-first), places) == 0:
                return

            standardMsg = '%s != %s within %r places' % (safe_repr(first),
                                                          safe_repr(second),
                                                          places)
        msg = self._formatMessage(msg, standardMsg)
        raise self.failureException(msg)

    def assertNotAlmostEqual(self, first, second, places=None, msg=None,
                             delta=None):
        """Fail if the two objects are equal as determined by their
           difference rounded to the given number of decimal places
           (default 7) and comparing to zero, or by comparing that the
           difference between the two objects is less than the given
           delta.

           Note that decimal places (from zero) are usually not the same
           as significant digits (measured from the most significant digit).

           Objects that are equal automatically fail.

           Raises:
               TypeError: if both delta and places are given.
        """
        if delta is not None and places is not None:
            raise TypeError("specify delta or places not both")
        if delta is not None:
            # Pass only if the objects are unequal AND farther apart than delta.
            if not (first == second) and abs(first - second) > delta:
                return
            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
                                                        safe_repr(second),
                                                        safe_repr(delta))
        else:
            if places is None:
                places = 7
            if not (first == second) and round(abs(second-first), places) != 0:
                return
            standardMsg = '%s == %s within %r places' % (safe_repr(first),
                                                         safe_repr(second),
                                                         places)

        msg = self._formatMessage(msg, standardMsg)
        raise self.failureException(msg)


    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences.
        """
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # `differing` collects the failure description. It stays None only
        # while both arguments support len().
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length. Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length. Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            differing = '%ss differ: %s != %s\n' % (
                    (seq_type_name.capitalize(),) +
                    _common_shorten_repr(seq1, seq2))

            # Report the first index in the common prefix where the items
            # differ or cannot be indexed.
            for i in range(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 ((i,) + _common_shorten_repr(item1, item2)))
                    break
            else:
                # Loop finished without break: the common prefix is equal.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, safe_repr(seq1[len2])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    # "First extra element" means the first surplus item,
                    # which here comes from seq2.
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, safe_repr(seq2[len1])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))

        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)

    def _truncateMessage(self, message, diff):
        """Return *message* with *diff* appended, honouring self.maxDiff.

        If maxDiff is None, or *diff* is no longer than maxDiff, the full
        diff is appended. Otherwise DIFF_OMITTED, formatted with the
        diff's length, is appended instead.
        """
        max_diff = self.maxDiff
        if max_diff is None or len(diff) <= max_diff:
            return message + diff
        return message + (DIFF_OMITTED % len(diff))

    def assertListEqual(self, list1, list2, msg=None):
        """A list-specific equality assertion.

        Args:
            list1: The first list to compare.
            list2: The second list to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.

        """
        self.assertSequenceEqual(list1, list2, msg, seq_type=list)

    def assertTupleEqual(self, tuple1, tuple2, msg=None):
        """A tuple-specific equality assertion.

        Args:
            tuple1: The first tuple to compare.
            tuple2: The second tuple to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.
        """
        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)

    def assertSetEqual(self, set1, set2, msg=None):
        """A set-specific equality assertion.

        Args:
            set1: The first set to compare.
            set2: The second set to compare.
            msg: Optional message to use on failure instead of a list of
                    differences.

        assertSetEqual uses ducktyping to support different types of sets, and
        is optimized for sets specifically (parameters must support a
        difference method).
        """
        # self.fail() raises, so difference1/difference2 are always bound
        # when execution continues past each try.
        try:
            difference1 = set1.difference(set2)
        except TypeError as e:
            self.fail('invalid type when attempting set difference: %s' % e)
        except AttributeError as e:
            self.fail('first argument does not support set difference: %s' % e)

        try:
            difference2 = set2.difference(set1)
        except TypeError as e:
            self.fail('invalid type when attempting set difference: %s' % e)
        except AttributeError as e:
            self.fail('second argument does not support set difference: %s' % e)

        if not (difference1 or difference2):
            return

        lines = []
        if difference1:
            lines.append('Items in the first set but not the second:')
            for item in difference1:
                lines.append(repr(item))
        if difference2:
            lines.append('Items in the second set but not the first:')
            for item in difference2:
                lines.append(repr(item))

        standardMsg = '\n'.join(lines)
        self.fail(self._formatMessage(msg, standardMsg))

    def assertIn(self, member, container, msg=None):
        """Just like self.assertTrue(a in b), but with a nicer default message."""
        if member not in container:
            standardMsg = '%s not found in %s' % (safe_repr(member),
                                                  safe_repr(container))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertNotIn(self, member, container, msg=None):
        """Just like self.assertTrue(a not in b), but with a nicer default message."""
        if member in container:
            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
                                                        safe_repr(container))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertIs(self, expr1, expr2, msg=None):
        """Just like self.assertTrue(a is b), but with a nicer default message."""
        if expr1 is not expr2:
            standardMsg = '%s is not %s' % (safe_repr(expr1),
                                             safe_repr(expr2))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertIsNot(self, expr1, expr2, msg=None):
        """Just like self.assertTrue(a is not b), but with a nicer default message."""
        if expr1 is expr2:
            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertDictEqual(self, d1, d2, msg=None):
        """Fail unless d1 == d2; both arguments must be dict instances.

        On failure the message contains a shortened repr of both dicts plus
        an ndiff of their pprint output, truncated via _truncateMessage.
        """
        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')

        if d1 != d2:
            standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
            diff = ('\n' + '\n'.join(difflib.ndiff(
                           pprint.pformat(d1).splitlines(),
                           pprint.pformat(d2).splitlines())))
            standardMsg = self._truncateMessage(standardMsg, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertDictContainsSubset(self, subset, dictionary, msg=None):
        """Checks whether dictionary is a superset of subset.

        Deprecated: always emits a DeprecationWarning. Fails if any key of
        *subset* is missing from *dictionary* or maps to a different value.
        """
        warnings.warn('assertDictContainsSubset is deprecated',
                      DeprecationWarning)
        missing = []
        mismatched = []
        for key, value in subset.items():
            if key not in dictionary:
                missing.append(key)
            elif value != dictionary[key]:
                mismatched.append('%s, expected: %s, actual: %s' %
                                  (safe_repr(key), safe_repr(value),
                                   safe_repr(dictionary[key])))

        if not (missing or mismatched):
            return

        standardMsg = ''
        if missing:
            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
                                                    missing)
        if mismatched:
            if standardMsg:
                standardMsg += '; '
            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)

        self.fail(self._formatMessage(msg, standardMsg))


    def assertCountEqual(self, first, second, msg=None):
        """Assert that two iterables contain the same elements, regardless
        of their order. If the same element occurs more than once, it
        verifies that the elements occur the same number of times.

            self.assertEqual(Counter(list(first)),
                             Counter(list(second)))

         Example:
            - [0, 1, 1] and [1, 0, 1] compare equal.
            - [0, 0, 1] and [0, 1] compare unequal.

        """
        first_seq, second_seq = list(first), list(second)
        try:
            first = collections.Counter(first_seq)
            second = collections.Counter(second_seq)
        except TypeError:
            # Handle case with unhashable elements
            differences = _count_diff_all_purpose(first_seq, second_seq)
        else:
            if first == second:
                return
            differences = _count_diff_hashable(first_seq, second_seq)

        if differences:
            standardMsg = 'Element counts were not equal:\n'
            lines = ['First has %d, Second has %d: %r' % diff for diff in differences]
            diffMsg = '\n'.join(lines)
            standardMsg = self._truncateMessage(standardMsg, diffMsg)
            msg = self._formatMessage(msg, standardMsg)
            self.fail(msg)

    def assertMultiLineEqual(self, first, second, msg=None):
        """Assert that two multi-line strings are equal."""
        self.assertIsInstance(first, str, 'First argument is not a string')
        self.assertIsInstance(second, str, 'Second argument is not a string')

        if first != second:
            # don't use difflib if the strings are too long
            if (len(first) > self._diffThreshold or
                len(second) > self._diffThreshold):
                # For str inputs first != second, so this call raises and the
                # difflib path below is not reached.
                self._baseAssertEqual(first, second, msg)
            firstlines = first.splitlines(keepends=True)
            secondlines = second.splitlines(keepends=True)
            if len(firstlines) == 1 and first.strip('\r\n') == first:
                # A single unterminated line: append '\n' to both sides so
                # the ''.join()ed ndiff output stays line-separated.
                firstlines = [first + '\n']
                secondlines = [second + '\n']
            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
            standardMsg = self._truncateMessage(standardMsg, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertLess(self, a, b, msg=None):
        """Just like self.assertTrue(a < b), but with a nicer default message."""
        if not a < b:
            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertLessEqual(self, a, b, msg=None):
        """Just like self.assertTrue(a <= b), but with a nicer default message."""
        if not a <= b:
            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertGreater(self, a, b, msg=None):
        """Just like self.assertTrue(a > b), but with a nicer default message."""
        if not a > b:
            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
            self.fail(self._formatMessage(msg, standardMsg))

    def assertGreaterEqual(self, a, b, msg=None):
        """Just like self.assertTrue(a >= b), but with a nicer default message."""
        if not a >= b:
            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
            self.fail(self._formatMessage(msg, standardMsg))

# (TestCase methods, continued)
    def assertIsNone(self, obj, msg=None):
        """Same as self.assertTrue(obj is None), with a nicer default message."""
        if obj is not None:
            standardMsg = '%s is not None' % (safe_repr(obj),)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertIsNotNone(self, obj, msg=None):
        """Included for symmetry with assertIsNone."""
        if obj is None:
            standardMsg = 'unexpectedly None'
            self.fail(self._formatMessage(msg, standardMsg))

    def assertIsInstance(self, obj, cls, msg=None):
        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
        default message."""
        if not isinstance(obj, cls):
            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertNotIsInstance(self, obj, cls, msg=None):
        """Included for symmetry with assertIsInstance."""
        if isinstance(obj, cls):
            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertRaisesRegex(self, expected_exception, expected_regex,
                          *args, **kwargs):
        """Asserts that the message in a raised exception matches a regex.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_regex: Regex (re pattern object or string) expected
                    to be found in error message.
            args: Function to be called and extra positional args.
            kwargs: Extra kwargs.
            msg: Optional message used in case of failure. Can only be used
                    when assertRaisesRegex is used as a context manager.
        """
        context = _AssertRaisesContext(expected_exception, self, expected_regex)
        return context.handle('assertRaisesRegex', args, kwargs)

    def assertWarnsRegex(self, expected_warning, expected_regex,
                         *args, **kwargs):
        """Asserts that the message in a triggered warning matches a regexp.
        Basic functioning is similar to assertWarns() with the addition
        that only warnings whose messages also match the regular expression
        are considered successful matches.

        Args:
            expected_warning: Warning class expected to be triggered.
            expected_regex: Regex (re pattern object or string) expected
                    to be found in error message.
            args: Function to be called and extra positional args.
            kwargs: Extra kwargs.
            msg: Optional message used in case of failure. Can only be used
                    when assertWarnsRegex is used as a context manager.
        """
        context = _AssertWarnsContext(expected_warning, self, expected_regex)
        return context.handle('assertWarnsRegex', args, kwargs)

    def assertRegex(self, text, expected_regex, msg=None):
        """Fail the test unless the text matches the regular expression.

        *expected_regex* may be a str/bytes pattern (compiled here) or an
        already-compiled pattern object; re.search semantics are used.
        """
        if isinstance(expected_regex, (str, bytes)):
            # NOTE: this guard is an `assert`, so it is skipped under
            # `python -O` and an empty pattern then matches everything.
            assert expected_regex, "expected_regex must not be empty."
            expected_regex = re.compile(expected_regex)
        if not expected_regex.search(text):
            standardMsg = "Regex didn't match: %r not found in %r" % (
                expected_regex.pattern, text)
            # _formatMessage ensures the longMessage option is respected
            msg = self._formatMessage(msg, standardMsg)
            raise self.failureException(msg)

    def assertNotRegex(self, text, unexpected_regex, msg=None):
        """Fail the test if the text matches the regular expression.

        The failure message quotes the matched substring of *text*.
        """
        if isinstance(unexpected_regex, (str, bytes)):
            unexpected_regex = re.compile(unexpected_regex)
        match = unexpected_regex.search(text)
        if match:
            standardMsg = 'Regex matched: %r matches %r in %r' % (
                text[match.start() : match.end()],
                unexpected_regex.pattern,
                text)
            # _formatMessage ensures the longMessage option is respected
            msg = self._formatMessage(msg, standardMsg)
            raise self.failureException(msg)


    def _deprecate(original_func):
        """Class-body helper: wrap *original_func* so each call first emits
        a DeprecationWarning naming it (stacklevel 2, i.e. the caller)."""
        def deprecated_func(*args, **kwargs):
            warnings.warn(
                'Please use {0} instead.'.format(original_func.__name__),
                DeprecationWarning, 2)
            return original_func(*args, **kwargs)
        return deprecated_func

    # see #9424 -- legacy assertion names kept as deprecated aliases.
    failUnlessEqual = assertEquals = _deprecate(assertEqual)
    failIfEqual = assertNotEquals = _deprecate(assertNotEqual)
    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(assertAlmostEqual)
    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(assertNotAlmostEqual)
    failUnless = assert_ = _deprecate(assertTrue)
    failUnlessRaises = _deprecate(assertRaises)
    failIf = _deprecate(assertFalse)
    assertRaisesRegexp = _deprecate(assertRaisesRegex)
    assertRegexpMatches = _deprecate(assertRegex)
    assertNotRegexpMatches = _deprecate(assertNotRegex)



class FunctionTestCase(TestCase):
    """A test case that wraps a test function.

    This is useful for slipping pre-existing test functions into the
    unittest framework. Optionally, set-up and tidy-up functions can be
    supplied. As with TestCase, the tidy-up ('tearDown') function will
    always be called if the set-up ('setUp') function ran successfully.

    Args (constructor):
        testFunc: Zero-argument callable run as the test body.
        setUp: Optional zero-argument callable run by setUp().
        tearDown: Optional zero-argument callable run by tearDown().
        description: Optional text returned by shortDescription().
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        if self._setUpFunc is not None:
            self._setUpFunc()

    def tearDown(self):
        if self._tearDownFunc is not None:
            self._tearDownFunc()

    def runTest(self):
        self._testFunc()

    def id(self):
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented

        return self._setUpFunc == other._setUpFunc and \
               self._tearDownFunc == other._tearDownFunc and \
               self._testFunc == other._testFunc and \
               self._description == other._description

    def __hash__(self):
        # Hash the same fields __eq__ compares, plus the concrete type.
        return hash((type(self), self._setUpFunc, self._tearDownFunc,
                     self._testFunc, self._description))

    def __str__(self):
        return "%s (%s)" % (strclass(self.__class__),
                            self._testFunc.__name__)

    def __repr__(self):
        return "<%s tec=%s>" % (strclass(self.__class__),
                                     self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the first line of the
        wrapped function's docstring, else None."""
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        return doc and doc.split("\n")[0].strip() or None


class _SubTest(TestCase):
    """Result-reporting stand-in for one subtest of *test_case*.

    It is never run itself (runTest raises). Its id() and str() are those
    of the parent test followed by the optional message and parameters.
    """

    def __init__(self, test_case, message, params):
        super().__init__()
        self._message = message
        self.test_case = test_case
        # Mapping of subtest parameters (iterated via .items() below).
        self.params = params
        self.failureException = test_case.failureException

    def runTest(self):
        raise NotImplementedError("subtests cannot be run directly")

    def _subDescription(self):
        """Return "[message] (k=v, ...)", omitting absent parts, or '()'."""
        parts = []
        if self._message is not _subtest_msg_sentinel:
            parts.append("[{}]".format(self._message))
        if self.params:
            params_desc = ', '.join(
                "{}={!r}".format(k, v)
                for (k, v) in sorted(self.params.items()))
            parts.append("({})".format(params_desc))
        return " ".join(parts) or '()'

    def id(self):
        return "{} {}".format(self.test_case.id(), self._subDescription())

    def shortDescription(self):
        """Returns a one-line description of the subtest, or None if no
        description has been provided.
        """
        return self.test_case.shortDescription()

    def __str__(self):
        return "{} {}".format(self.test_case, self._subDescription())

# ------------------------------------------------------------------------------
# /fastunit/loader.py
# ------------------------------------------------------------------------------
"""Loading unittests."""

import os
import re
import sys
import traceback
import types
import functools
import warnings

from fnmatch import fnmatch

from . import case, suite, util

# NOTE(review): presumably marks this module's frames as framework-internal
# so they are hidden from failure tracebacks -- confirm where it is checked.
__unittest = True

# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


class _FailedTest(case.TestCase):
    """Placeholder test whose test method re-raises a stored exception.

    Used to report import/load/attribute errors as test errors.
    """
    # Class-level default so __getattr__ can read it even before an instance
    # value is set (presumably by TestCase.__init__ -- not visible here).
    _testMethodName = None

    def __init__(self, method_name, exception):
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        # Only called for attributes not found normally: synthesize the
        # test method; anything else is delegated upward.
        if name != self._testMethodName:
            # NOTE(review): unless a base class defines __getattr__, this
            # lookup itself raises AttributeError, which still reports the
            # attribute as missing.
            return super(_FailedTest, self).__getattr__(name)
        def testFailure():
            raise self._exception
        return testFailure


def _make_failed_import_test(name, suiteClass):
    """Return (suite, message) for a module that failed to import.

    Must be called while handling the exception: the current traceback is
    embedded in the message.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test(name, ImportError(message), suiteClass, message)

def _make_failed_load_tests(name, exception, suiteClass):
    """Return (suite, message) for a load_tests() call that raised."""
    message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),)
    return _make_failed_test(
        name, exception, suiteClass, message)

def _make_failed_test(methodname, exception, suiteClass, message):
    """Wrap *exception* in a one-test suite; return (suite, message)."""
    test = _FailedTest(methodname, exception)
    return suiteClass((test,)), message

def _make_skipped_test(methodname, exception, suiteClass):
    """Return a one-test suite whose test is skipped with str(exception)."""
    @case.skip(str(exception))
    def testSkipped(self):
        pass
    attrs = {methodname: testSkipped}
    TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))

def _jython_aware_splitext(path):
    """Strip the extension from *path*, treating '$py.class' (Jython's
    compiled-file suffix, 9 characters) as a single extension."""
    if path.lower().endswith('$py.class'):
        return path[:-9]
    return os.path.splitext(path)[0]


class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a method as a test.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names (falsy = unsorted).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Container type used for every suite this loader returns.
    suiteClass = suite.TestSuite
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Error messages for tests that failed to load.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass

        Falls back to a single 'runTest' test when no prefixed test methods
        exist but the class has runTest.

        Raises:
            TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws. See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module

        If the module defines load_tests, it is called as
        load_tests(self, tests, pattern) and its result is returned instead;
        an exception from it is recorded in self.errors and turned into a
        failing test.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute-lookup failures are recorded in self.errors and
        returned as a suite containing a failing test.

        Raises:
            TypeError: if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returns the top-level package, so walk the
            # remaining parts (after the first) as attributes below.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # obj is a package (only packages have __path__, per the
                    # importlib docs), and we encountered an error importing
                    # something. We cannot tell
                    # the difference between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            # (they fall through to the callable(obj) branch below).
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies if it starts with testMethodPrefix and the class
        attribute is callable. Sorting uses sortTestMethodsUsing (a cmp
        function) when it is set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            # No explicit or inherited top level: use start_dir and remember
            # that it was chosen implicitly.
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if not top_level_dir in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
277 | sys.path.insert(0, top_level_dir) 278 | self._top_level_dir = top_level_dir 279 | 280 | is_not_importable = False 281 | is_namespace = False 282 | tests = [] 283 | if os.path.isdir(os.path.abspath(start_dir)): 284 | start_dir = os.path.abspath(start_dir) 285 | if start_dir != top_level_dir: 286 | is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py')) 287 | else: 288 | # support for discovery from dotted module names 289 | try: 290 | __import__(start_dir) 291 | except ImportError: 292 | is_not_importable = True 293 | else: 294 | the_module = sys.modules[start_dir] 295 | top_part = start_dir.split('.')[0] 296 | try: 297 | start_dir = os.path.abspath( 298 | os.path.dirname((the_module.__file__))) 299 | except AttributeError: 300 | # look for namespace packages 301 | try: 302 | spec = the_module.__spec__ 303 | except AttributeError: 304 | spec = None 305 | 306 | if spec and spec.loader is None: 307 | if spec.submodule_search_locations is not None: 308 | is_namespace = True 309 | 310 | for path in the_module.__path__: 311 | if (not set_implicit_top and 312 | not path.startswith(top_level_dir)): 313 | continue 314 | self._top_level_dir = \ 315 | (path.split(the_module.__name__ 316 | .replace(".", os.path.sep))[0]) 317 | tests.extend(self._find_tests(path, 318 | pattern, 319 | namespace=True)) 320 | elif the_module.__name__ in sys.builtin_module_names: 321 | # builtin module 322 | raise TypeError('Can not use builtin modules ' 323 | 'as dotted module names') from None 324 | else: 325 | raise TypeError( 326 | 'don\'t know how to discover from {!r}' 327 | .format(the_module)) from None 328 | 329 | if set_implicit_top: 330 | if not is_namespace: 331 | self._top_level_dir = \ 332 | self._get_directory_containing_module(top_part) 333 | sys.path.remove(top_level_dir) 334 | else: 335 | sys.path.remove(top_level_dir) 336 | 337 | if is_not_importable: 338 | raise ImportError('Start directory is not importable: %r' % start_dir) 339 | 340 | if not 
is_namespace: 341 | tests = list(self._find_tests(start_dir, pattern)) 342 | return self.suiteClass(tests) 343 | 344 | def _get_directory_containing_module(self, module_name): 345 | module = sys.modules[module_name] 346 | full_path = os.path.abspath(module.__file__) 347 | 348 | if os.path.basename(full_path).lower().startswith('__init__.py'): 349 | return os.path.dirname(os.path.dirname(full_path)) 350 | else: 351 | # here we have been given a module rather than a package - so 352 | # all we can do is search the *same* directory the module is in 353 | # should an exception be raised instead 354 | return os.path.dirname(full_path) 355 | 356 | def _get_name_from_path(self, path): 357 | if path == self._top_level_dir: 358 | return '.' 359 | path = _jython_aware_splitext(os.path.normpath(path)) 360 | 361 | _relpath = os.path.relpath(path, self._top_level_dir) 362 | assert not os.path.isabs(_relpath), "Path must be within the project" 363 | assert not _relpath.startswith('..'), "Path must be within the project" 364 | 365 | name = _relpath.replace(os.path.sep, '.') 366 | return name 367 | 368 | def _get_module_from_name(self, name): 369 | __import__(name) 370 | return sys.modules[name] 371 | 372 | def _match_path(self, path, full_path, pattern): 373 | # override this method to use alternative matching strategy 374 | return fnmatch(path, pattern) 375 | 376 | def _find_tests(self, start_dir, pattern, namespace=False): 377 | """Used by discovery. Yields test suites it loads.""" 378 | # Handle the __init__ in this package 379 | name = self._get_name_from_path(start_dir) 380 | # name is '.' when start_dir == top_level_dir (and top_level_dir is by 381 | # definition not a package). 382 | if name != '.' and name not in self._loading_packages: 383 | # name is in self._loading_packages while we have called into 384 | # loadTestsFromModule with name. 
385 | tests, should_recurse = self._find_test_path( 386 | start_dir, pattern, namespace) 387 | if tests is not None: 388 | yield tests 389 | if not should_recurse: 390 | # Either an error occurred, or load_tests was used by the 391 | # package. 392 | return 393 | # Handle the contents. 394 | paths = sorted(os.listdir(start_dir)) 395 | for path in paths: 396 | full_path = os.path.join(start_dir, path) 397 | tests, should_recurse = self._find_test_path( 398 | full_path, pattern, namespace) 399 | if tests is not None: 400 | yield tests 401 | if should_recurse: 402 | # we found a package that didn't use load_tests. 403 | name = self._get_name_from_path(full_path) 404 | self._loading_packages.add(name) 405 | try: 406 | yield from self._find_tests(full_path, pattern, namespace) 407 | finally: 408 | self._loading_packages.discard(name) 409 | 410 | def _find_test_path(self, full_path, pattern, namespace=False): 411 | """Used by discovery. 412 | 413 | Loads tests from a single file, or a directories' __init__.py when 414 | passed the directory. 415 | 416 | Returns a tuple (None_or_tests_from_file, should_recurse). 
417 | """ 418 | basename = os.path.basename(full_path) 419 | if os.path.isfile(full_path): 420 | if not VALID_MODULE_NAME.match(basename): 421 | # valid Python identifiers only 422 | return None, False 423 | if not self._match_path(basename, full_path, pattern): 424 | return None, False 425 | # if the test file matches, load it 426 | name = self._get_name_from_path(full_path) 427 | try: 428 | module = self._get_module_from_name(name) 429 | except case.SkipTest as e: 430 | return _make_skipped_test(name, e, self.suiteClass), False 431 | except: 432 | error_case, error_message = \ 433 | _make_failed_import_test(name, self.suiteClass) 434 | self.errors.append(error_message) 435 | return error_case, False 436 | else: 437 | mod_file = os.path.abspath( 438 | getattr(module, '__file__', full_path)) 439 | realpath = _jython_aware_splitext( 440 | os.path.realpath(mod_file)) 441 | fullpath_noext = _jython_aware_splitext( 442 | os.path.realpath(full_path)) 443 | if realpath.lower() != fullpath_noext.lower(): 444 | module_dir = os.path.dirname(realpath) 445 | mod_name = _jython_aware_splitext( 446 | os.path.basename(full_path)) 447 | expected_dir = os.path.dirname(full_path) 448 | msg = ("%r module incorrectly imported from %r. Expected " 449 | "%r. 
Is this module globally installed?") 450 | raise ImportError( 451 | msg % (mod_name, module_dir, expected_dir)) 452 | return self.loadTestsFromModule(module, pattern=pattern), False 453 | elif os.path.isdir(full_path): 454 | if (not namespace and 455 | not os.path.isfile(os.path.join(full_path, '__init__.py'))): 456 | return None, False 457 | 458 | load_tests = None 459 | tests = None 460 | name = self._get_name_from_path(full_path) 461 | try: 462 | package = self._get_module_from_name(name) 463 | except case.SkipTest as e: 464 | return _make_skipped_test(name, e, self.suiteClass), False 465 | except: 466 | error_case, error_message = \ 467 | _make_failed_import_test(name, self.suiteClass) 468 | self.errors.append(error_message) 469 | return error_case, False 470 | else: 471 | load_tests = getattr(package, 'load_tests', None) 472 | # Mark this package as being in load_tests (possibly ;)) 473 | self._loading_packages.add(name) 474 | try: 475 | tests = self.loadTestsFromModule(package, pattern=pattern) 476 | if load_tests is not None: 477 | # loadTestsFromModule(package) has loaded tests for us. 
478 | return tests, False 479 | return tests, True 480 | finally: 481 | self._loading_packages.discard(name) 482 | else: 483 | return None, False 484 | 485 | 486 | defaultTestLoader = TestLoader() 487 | 488 | 489 | def _makeLoader(prefix, sortUsing, suiteClass=None): 490 | loader = TestLoader() 491 | loader.sortTestMethodsUsing = sortUsing 492 | loader.testMethodPrefix = prefix 493 | if suiteClass: 494 | loader.suiteClass = suiteClass 495 | return loader 496 | 497 | def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp): 498 | return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass) 499 | 500 | def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp, 501 | suiteClass=suite.TestSuite): 502 | return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase( 503 | testCaseClass) 504 | 505 | def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp, 506 | suiteClass=suite.TestSuite): 507 | return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\ 508 | module) 509 | -------------------------------------------------------------------------------- /fastunit/main.py: -------------------------------------------------------------------------------- 1 | """Unittest main program""" 2 | 3 | import sys 4 | import argparse 5 | import os 6 | 7 | from . 
import loader, runner 8 | from .signals import installHandler 9 | 10 | __unittest = True 11 | 12 | MAIN_EXAMPLES = """\ 13 | Examples: 14 | %(prog)s test_module - run tests from test_module 15 | %(prog)s module.TestClass - run tests from module.TestClass 16 | %(prog)s module.Class.test_method - run specified test method 17 | %(prog)s path/to/test_file.py - run tests from test_file.py 18 | """ 19 | 20 | MODULE_EXAMPLES = """\ 21 | Examples: 22 | %(prog)s - run default set of tests 23 | %(prog)s MyTestSuite - run suite 'MyTestSuite' 24 | %(prog)s MyTestCase.testSomething - run MyTestCase.testSomething 25 | %(prog)s MyTestCase - run all 'test*' test methods 26 | in MyTestCase 27 | """ 28 | 29 | def _convert_name(name): 30 | # on Linux / Mac OS X 'foo.PY' is not importable, but on 31 | # Windows it is. Simpler to do a case insensitive match 32 | # a better check would be to check that the name is a 33 | # valid Python module name. 34 | if os.path.isfile(name) and name.lower().endswith('.py'): 35 | if os.path.isabs(name): 36 | rel_path = os.path.relpath(name, os.getcwd()) 37 | if os.path.isabs(rel_path) or rel_path.startswith(os.pardir): 38 | return name 39 | name = rel_path 40 | # on Windows both '\' and '/' are used as path 41 | # separators. Better to replace both than rely on os.path.sep 42 | return name[:-3].replace('\\', '.').replace('/', '.') 43 | return name 44 | 45 | def _convert_names(names): 46 | return [_convert_name(name) for name in names] 47 | 48 | 49 | class TestProgram(object): 50 | """A command-line program that runs a set of tests; this is primarily 51 | for making test modules conveniently executable. 
52 | """ 53 | # defaults for testing 54 | module=None 55 | verbosity = 1 56 | failfast = catchbreak = buffer = progName = warnings = None 57 | _discovery_parser = None 58 | 59 | def __init__(self, module='__main__', defaultTest=None, argv=None, 60 | testRunner=None, testLoader=loader.defaultTestLoader, 61 | exit=True, verbosity=1, failfast=None, catchbreak=None, 62 | buffer=None, warnings=None, *, tb_locals=False): 63 | if isinstance(module, str): 64 | self.module = __import__(module) 65 | for part in module.split('.')[1:]: 66 | self.module = getattr(self.module, part) 67 | else: 68 | self.module = module 69 | if argv is None: 70 | argv = sys.argv 71 | 72 | self.exit = exit 73 | self.failfast = failfast 74 | self.catchbreak = catchbreak 75 | self.verbosity = verbosity 76 | self.buffer = buffer 77 | self.tb_locals = tb_locals 78 | if warnings is None and not sys.warnoptions: 79 | # even if DeprecationWarnings are ignored by default 80 | # print them anyway unless other warnings settings are 81 | # specified by the warnings arg or the -W python flag 82 | self.warnings = 'default' 83 | else: 84 | # here self.warnings is set either to the value passed 85 | # to the warnings args or to None. 86 | # If the user didn't pass a value self.warnings will 87 | # be None. This means that the behavior is unchanged 88 | # and depends on the values passed to -W. 
89 | self.warnings = warnings 90 | self.defaultTest = defaultTest 91 | self.testRunner = testRunner 92 | self.testLoader = testLoader 93 | self.progName = os.path.basename(argv[0]) 94 | self.parseArgs(argv) 95 | self.runTests() 96 | 97 | def usageExit(self, msg=None): 98 | if msg: 99 | print(msg) 100 | if self._discovery_parser is None: 101 | self._initArgParsers() 102 | self._print_help() 103 | sys.exit(2) 104 | 105 | def _print_help(self, *args, **kwargs): 106 | if self.module is None: 107 | print(self._main_parser.format_help()) 108 | print(MAIN_EXAMPLES % {'prog': self.progName}) 109 | self._discovery_parser.print_help() 110 | else: 111 | print(self._main_parser.format_help()) 112 | print(MODULE_EXAMPLES % {'prog': self.progName}) 113 | 114 | def parseArgs(self, argv): 115 | self._initArgParsers() 116 | if self.module is None: 117 | if len(argv) > 1 and argv[1].lower() == 'discover': 118 | self._do_discovery(argv[2:]) 119 | return 120 | self._main_parser.parse_args(argv[1:], self) 121 | if not self.tests: 122 | # this allows "python -m unittest -v" to still work for 123 | # test discovery. 124 | self._do_discovery([]) 125 | return 126 | else: 127 | self._main_parser.parse_args(argv[1:], self) 128 | 129 | if self.tests: 130 | self.testNames = _convert_names(self.tests) 131 | if __name__ == '__main__': 132 | # to support python -m unittest ... 
133 | self.module = None 134 | elif self.defaultTest is None: 135 | # createTests will load tests from self.module 136 | self.testNames = None 137 | elif isinstance(self.defaultTest, str): 138 | self.testNames = (self.defaultTest,) 139 | else: 140 | self.testNames = list(self.defaultTest) 141 | self.createTests() 142 | 143 | def createTests(self): 144 | if self.testNames is None: 145 | self.test = self.testLoader.loadTestsFromModule(self.module) 146 | else: 147 | self.test = self.testLoader.loadTestsFromNames(self.testNames, 148 | self.module) 149 | 150 | def _initArgParsers(self): 151 | parent_parser = self._getParentArgParser() 152 | self._main_parser = self._getMainArgParser(parent_parser) 153 | self._discovery_parser = self._getDiscoveryArgParser(parent_parser) 154 | 155 | def _getParentArgParser(self): 156 | parser = argparse.ArgumentParser(add_help=False) 157 | 158 | parser.add_argument('-v', '--verbose', dest='verbosity', 159 | action='store_const', const=2, 160 | help='Verbose output') 161 | parser.add_argument('-q', '--quiet', dest='verbosity', 162 | action='store_const', const=0, 163 | help='Quiet output') 164 | parser.add_argument('--locals', dest='tb_locals', 165 | action='store_true', 166 | help='Show local variables in tracebacks') 167 | if self.failfast is None: 168 | parser.add_argument('-f', '--failfast', dest='failfast', 169 | action='store_true', 170 | help='Stop on first fail or error') 171 | self.failfast = False 172 | if self.catchbreak is None: 173 | parser.add_argument('-c', '--catch', dest='catchbreak', 174 | action='store_true', 175 | help='Catch Ctrl-C and display results so far') 176 | self.catchbreak = False 177 | if self.buffer is None: 178 | parser.add_argument('-b', '--buffer', dest='buffer', 179 | action='store_true', 180 | help='Buffer stdout and stderr during tests') 181 | self.buffer = False 182 | 183 | return parser 184 | 185 | def _getMainArgParser(self, parent): 186 | parser = argparse.ArgumentParser(parents=[parent]) 187 | 
parser.prog = self.progName 188 | parser.print_help = self._print_help 189 | 190 | parser.add_argument('tests', nargs='*', 191 | help='a list of any number of test modules, ' 192 | 'classes and test methods.') 193 | 194 | return parser 195 | 196 | def _getDiscoveryArgParser(self, parent): 197 | parser = argparse.ArgumentParser(parents=[parent]) 198 | parser.prog = '%s discover' % self.progName 199 | parser.epilog = ('For test discovery all test modules must be ' 200 | 'importable from the top level directory of the ' 201 | 'project.') 202 | 203 | parser.add_argument('-s', '--start-directory', dest='start', 204 | help="Directory to start discovery ('.' default)") 205 | parser.add_argument('-p', '--pattern', dest='pattern', 206 | help="Pattern to match tests ('test*.py' default)") 207 | parser.add_argument('-t', '--top-level-directory', dest='top', 208 | help='Top level directory of project (defaults to ' 209 | 'start directory)') 210 | for arg in ('start', 'pattern', 'top'): 211 | parser.add_argument(arg, nargs='?', 212 | default=argparse.SUPPRESS, 213 | help=argparse.SUPPRESS) 214 | 215 | return parser 216 | 217 | def _do_discovery(self, argv, Loader=None): 218 | self.start = '.' 
219 | self.pattern = 'test*.py' 220 | self.top = None 221 | if argv is not None: 222 | # handle command line args for test discovery 223 | if self._discovery_parser is None: 224 | # for testing 225 | self._initArgParsers() 226 | self._discovery_parser.parse_args(argv, self) 227 | 228 | loader = self.testLoader if Loader is None else Loader() 229 | self.test = loader.discover(self.start, self.pattern, self.top) 230 | 231 | def runTests(self): 232 | if self.catchbreak: 233 | installHandler() 234 | if self.testRunner is None: 235 | self.testRunner = runner.TextTestRunner 236 | if isinstance(self.testRunner, type): 237 | try: 238 | try: 239 | testRunner = self.testRunner(verbosity=self.verbosity, 240 | failfast=self.failfast, 241 | buffer=self.buffer, 242 | warnings=self.warnings, 243 | tb_locals=self.tb_locals) 244 | except TypeError: 245 | # didn't accept the tb_locals argument 246 | testRunner = self.testRunner(verbosity=self.verbosity, 247 | failfast=self.failfast, 248 | buffer=self.buffer, 249 | warnings=self.warnings) 250 | except TypeError: 251 | # didn't accept the verbosity, buffer or failfast arguments 252 | testRunner = self.testRunner() 253 | else: 254 | # it is assumed to be a TestRunner instance 255 | testRunner = self.testRunner 256 | self.result = testRunner.run(self.test) 257 | if self.exit: 258 | sys.exit(not self.result.wasSuccessful()) 259 | 260 | main = TestProgram 261 | -------------------------------------------------------------------------------- /fastunit/mock.py: -------------------------------------------------------------------------------- 1 | # mock.py 2 | # Test tools for mocking and patching. 
# Maintained by Michael Foord
# Backport for other versions of Python available from
# https://pypi.org/project/mock/

__all__ = (
    'Mock',
    'MagicMock',
    'patch',
    'sentinel',
    'DEFAULT',
    'ANY',
    'call',
    'create_autospec',
    'FILTER_DIR',
    'NonCallableMock',
    'NonCallableMagicMock',
    'mock_open',
    'PropertyMock',
)


__version__ = '1.0'


import inspect
import pprint
import sys
import builtins
from types import MethodType, ModuleType
from functools import wraps, partial


# Names of the public (non-underscore) builtins.
_builtins = {name for name in dir(builtins) if not name.startswith('_')}

# Base classes that _is_exception() treats as exceptions; under Jython
# java.lang.Throwable is accepted as well.
BaseExceptions = (BaseException,)
if 'java' in sys.platform:
    # jython
    import java
    BaseExceptions = (BaseException, java.lang.Throwable)


FILTER_DIR = True

# Workaround for issue #12370
# Without this, the __class__ properties wouldn't be set correctly
_safe_super = super


def _is_instance_mock(obj):
    """Return True if *obj* is a mock, i.e. its real type subclasses
    NonCallableMock.

    ``isinstance`` can't be used on Mock objects because they override
    ``__class__``; ``type(obj)`` sees the real class. NonCallableMock is the
    base class for all mocks.
    """
    return issubclass(type(obj), NonCallableMock)


def _is_exception(obj):
    """Return True if *obj* is an exception instance or an exception class
    (relative to ``BaseExceptions``)."""
    return (
        isinstance(obj, BaseExceptions) or
        (isinstance(obj, type) and issubclass(obj, BaseExceptions))
    )


class _slotted(object):
    __slots__ = ['a']


# Do not use this tuple.  It was never documented as a public API.
# It will be removed.  It has no obvious signs of users on github.
DescriptorTypes = (
    type(_slotted.a),
    property,
)


def _get_signature_object(func, as_instance, eat_self):
    """
    Given an arbitrary, possibly callable object, try to create a suitable
    signature object.
    Return a (reduced func, signature) tuple, or None.

    If *func* is a class and *as_instance* is false, the signature of its
    ``__init__`` is used (with ``self`` skipped). If *func* is not one of
    ``FunctionTypes`` (defined elsewhere in this module), its ``__call__`` is
    used instead. When *eat_self* is true the first positional parameter is
    dropped by binding it to ``None`` with ``functools.partial``.
    """
    if isinstance(func, type) and not as_instance:
        # If it's a type and should be modelled as a type, use __init__.
        try:
            func = func.__init__
        except AttributeError:
            return None
        # Skip the `self` argument in __init__
        eat_self = True
    elif not isinstance(func, FunctionTypes):
        # If we really want to model an instance of the passed type,
        # __call__ should be looked up, not __init__.
        try:
            func = func.__call__
        except AttributeError:
            return None
    if eat_self:
        sig_func = partial(func, None)
    else:
        sig_func = func
    try:
        return func, inspect.signature(sig_func)
    except ValueError:
        # Certain callable types are not supported by inspect.signature()
        return None


def _check_signature(func, mock, skipfirst, instance=False):
    """Install ``_mock_check_sig`` on ``type(mock)`` so calls are validated
    against *func*'s signature.

    Does nothing if no signature can be derived for *func*. The checker is
    stored on the mock's class, so when it is looked up through the mock its
    first parameter (``_mock_self``) receives the mock itself and only the
    remaining arguments are bound against the signature.
    """
    sig = _get_signature_object(func, instance, skipfirst)
    if sig is None:
        return
    func, sig = sig
    def checksig(_mock_self, *args, **kwargs):
        sig.bind(*args, **kwargs)
    _copy_func_details(func, checksig)
    type(mock)._mock_check_sig = checksig


def _copy_func_details(func, funcopy):
    """Copy identifying metadata from *func* onto *funcopy*.

    Copies ``__name__``, ``__doc__``, ``__text_signature__``, ``__module__``,
    ``__defaults__`` and ``__kwdefaults__``. Each attribute is copied
    independently: one that *func* lacks, or that cannot be set on
    *funcopy*, is skipped, so callables without e.g. ``__name__`` (such as
    ``functools.partial`` objects) don't raise ``AttributeError`` here.
    """
    # we explicitly don't copy func.__dict__ into this copy as it would
    # expose original attributes that should be mocked
    for attribute in (
        '__name__', '__doc__', '__text_signature__',
        '__module__', '__defaults__', '__kwdefaults__',
    ):
        try:
            setattr(funcopy, attribute, getattr(func, attribute))
        except AttributeError:
            pass


def _callable(obj):
    """Return True if *obj* is a class or appears to be callable.

    ``staticmethod``/``classmethod`` descriptors and bound methods are
    unwrapped to their ``__func__`` first: ``classmethod`` objects (and
    ``staticmethod`` objects before Python 3.10) have no ``__call__``, so a
    static or class method taken from a class ``__dict__`` would otherwise
    be reported as non-callable.
    """
    if isinstance(obj, type):
        return True
    if isinstance(obj, (staticmethod, classmethod, MethodType)):
        return _callable(obj.__func__)
    return getattr(obj, '__call__', None) is not None
| 149 | def _is_list(obj): 150 | # checks for list or tuples 151 | # XXXX badly named! 152 | return type(obj) in (list, tuple) 153 | 154 | 155 | def _instance_callable(obj): 156 | """Given an object, return True if the object is callable. 157 | For classes, return True if instances would be callable.""" 158 | if not isinstance(obj, type): 159 | # already an instance 160 | return getattr(obj, '__call__', None) is not None 161 | 162 | # *could* be broken by a class overriding __mro__ or __dict__ via 163 | # a metaclass 164 | for base in (obj,) + obj.__mro__: 165 | if base.__dict__.get('__call__') is not None: 166 | return True 167 | return False 168 | 169 | 170 | def _set_signature(mock, original, instance=False): 171 | # creates a function with signature (*args, **kwargs) that delegates to a 172 | # mock. It still does signature checking by calling a lambda with the same 173 | # signature as the original. 174 | if not _callable(original): 175 | return 176 | 177 | skipfirst = isinstance(original, type) 178 | result = _get_signature_object(original, instance, skipfirst) 179 | if result is None: 180 | return mock 181 | func, sig = result 182 | def checksig(*args, **kwargs): 183 | sig.bind(*args, **kwargs) 184 | _copy_func_details(func, checksig) 185 | 186 | name = original.__name__ 187 | if not name.isidentifier(): 188 | name = 'funcopy' 189 | context = {'_checksig_': checksig, 'mock': mock} 190 | src = """def %s(*args, **kwargs): 191 | _checksig_(*args, **kwargs) 192 | return mock(*args, **kwargs)""" % name 193 | exec (src, context) 194 | funcopy = context[name] 195 | _setup_func(funcopy, mock) 196 | return funcopy 197 | 198 | 199 | def _setup_func(funcopy, mock): 200 | funcopy.mock = mock 201 | 202 | # can't use isinstance with mocks 203 | if not _is_instance_mock(mock): 204 | return 205 | 206 | def assert_called_with(*args, **kwargs): 207 | return mock.assert_called_with(*args, **kwargs) 208 | def assert_called_once_with(*args, **kwargs): 209 | return 
mock.assert_called_once_with(*args, **kwargs) 210 | def assert_has_calls(*args, **kwargs): 211 | return mock.assert_has_calls(*args, **kwargs) 212 | def assert_any_call(*args, **kwargs): 213 | return mock.assert_any_call(*args, **kwargs) 214 | def reset_mock(): 215 | funcopy.method_calls = _CallList() 216 | funcopy.mock_calls = _CallList() 217 | mock.reset_mock() 218 | ret = funcopy.return_value 219 | if _is_instance_mock(ret) and not ret is mock: 220 | ret.reset_mock() 221 | 222 | funcopy.called = False 223 | funcopy.call_count = 0 224 | funcopy.call_args = None 225 | funcopy.call_args_list = _CallList() 226 | funcopy.method_calls = _CallList() 227 | funcopy.mock_calls = _CallList() 228 | 229 | funcopy.return_value = mock.return_value 230 | funcopy.side_effect = mock.side_effect 231 | funcopy._mock_children = mock._mock_children 232 | 233 | funcopy.assert_called_with = assert_called_with 234 | funcopy.assert_called_once_with = assert_called_once_with 235 | funcopy.assert_has_calls = assert_has_calls 236 | funcopy.assert_any_call = assert_any_call 237 | funcopy.reset_mock = reset_mock 238 | 239 | mock._mock_delegate = funcopy 240 | 241 | 242 | def _is_magic(name): 243 | return '__%s__' % name[2:-2] == name 244 | 245 | 246 | class _SentinelObject(object): 247 | "A unique, named, sentinel object." 
248 | def __init__(self, name): 249 | self.name = name 250 | 251 | def __repr__(self): 252 | return 'sentinel.%s' % self.name 253 | 254 | 255 | class _Sentinel(object): 256 | """Access attributes to return a named object, usable as a sentinel.""" 257 | def __init__(self): 258 | self._sentinels = {} 259 | 260 | def __getattr__(self, name): 261 | if name == '__bases__': 262 | # Without this help(unittest.mock) raises an exception 263 | raise AttributeError 264 | return self._sentinels.setdefault(name, _SentinelObject(name)) 265 | 266 | 267 | sentinel = _Sentinel() 268 | 269 | DEFAULT = sentinel.DEFAULT 270 | _missing = sentinel.MISSING 271 | _deleted = sentinel.DELETED 272 | 273 | 274 | def _copy(value): 275 | if type(value) in (dict, list, tuple, set): 276 | return type(value)(value) 277 | return value 278 | 279 | 280 | _allowed_names = { 281 | 'return_value', '_mock_return_value', 'side_effect', 282 | '_mock_side_effect', '_mock_parent', '_mock_new_parent', 283 | '_mock_name', '_mock_new_name' 284 | } 285 | 286 | 287 | def _delegating_property(name): 288 | _allowed_names.add(name) 289 | _the_name = '_mock_' + name 290 | def _get(self, name=name, _the_name=_the_name): 291 | sig = self._mock_delegate 292 | if sig is None: 293 | return getattr(self, _the_name) 294 | return getattr(sig, name) 295 | def _set(self, value, name=name, _the_name=_the_name): 296 | sig = self._mock_delegate 297 | if sig is None: 298 | self.__dict__[_the_name] = value 299 | else: 300 | setattr(sig, name, value) 301 | 302 | return property(_get, _set) 303 | 304 | 305 | 306 | class _CallList(list): 307 | 308 | def __contains__(self, value): 309 | if not isinstance(value, list): 310 | return list.__contains__(self, value) 311 | len_value = len(value) 312 | len_self = len(self) 313 | if len_value > len_self: 314 | return False 315 | 316 | for i in range(0, len_self - len_value + 1): 317 | sub_list = self[i:i+len_value] 318 | if sub_list == value: 319 | return True 320 | return False 321 | 322 | 
def __repr__(self): 323 | return pprint.pformat(list(self)) 324 | 325 | 326 | def _check_and_set_parent(parent, value, name, new_name): 327 | if not _is_instance_mock(value): 328 | return False 329 | if ((value._mock_name or value._mock_new_name) or 330 | (value._mock_parent is not None) or 331 | (value._mock_new_parent is not None)): 332 | return False 333 | 334 | _parent = parent 335 | while _parent is not None: 336 | # setting a mock (value) as a child or return value of itself 337 | # should not modify the mock 338 | if _parent is value: 339 | return False 340 | _parent = _parent._mock_new_parent 341 | 342 | if new_name: 343 | value._mock_new_parent = parent 344 | value._mock_new_name = new_name 345 | if name: 346 | value._mock_parent = parent 347 | value._mock_name = name 348 | return True 349 | 350 | # Internal class to identify if we wrapped an iterator object or not. 351 | class _MockIter(object): 352 | def __init__(self, obj): 353 | self.obj = iter(obj) 354 | def __iter__(self): 355 | return self 356 | def __next__(self): 357 | return next(self.obj) 358 | 359 | class Base(object): 360 | _mock_return_value = DEFAULT 361 | _mock_side_effect = None 362 | def __init__(self, *args, **kwargs): 363 | pass 364 | 365 | 366 | 367 | class NonCallableMock(Base): 368 | """A non-callable version of `Mock`""" 369 | 370 | def __new__(cls, *args, **kw): 371 | # every instance has its own class 372 | # so we can create magic methods on the 373 | # class without stomping on other mocks 374 | new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__}) 375 | instance = object.__new__(new) 376 | return instance 377 | 378 | 379 | def __init__( 380 | self, spec=None, wraps=None, name=None, spec_set=None, 381 | parent=None, _spec_state=None, _new_name='', _new_parent=None, 382 | _spec_as_instance=False, _eat_self=None, unsafe=False, **kwargs 383 | ): 384 | if _new_parent is None: 385 | _new_parent = parent 386 | 387 | __dict__ = self.__dict__ 388 | __dict__['_mock_parent'] = parent 
389 | __dict__['_mock_name'] = name 390 | __dict__['_mock_new_name'] = _new_name 391 | __dict__['_mock_new_parent'] = _new_parent 392 | 393 | if spec_set is not None: 394 | spec = spec_set 395 | spec_set = True 396 | if _eat_self is None: 397 | _eat_self = parent is not None 398 | 399 | self._mock_add_spec(spec, spec_set, _spec_as_instance, _eat_self) 400 | 401 | __dict__['_mock_children'] = {} 402 | __dict__['_mock_wraps'] = wraps 403 | __dict__['_mock_delegate'] = None 404 | 405 | __dict__['_mock_called'] = False 406 | __dict__['_mock_call_args'] = None 407 | __dict__['_mock_call_count'] = 0 408 | __dict__['_mock_call_args_list'] = _CallList() 409 | __dict__['_mock_mock_calls'] = _CallList() 410 | 411 | __dict__['method_calls'] = _CallList() 412 | __dict__['_mock_unsafe'] = unsafe 413 | 414 | if kwargs: 415 | self.configure_mock(**kwargs) 416 | 417 | _safe_super(NonCallableMock, self).__init__( 418 | spec, wraps, name, spec_set, parent, 419 | _spec_state 420 | ) 421 | 422 | 423 | def attach_mock(self, mock, attribute): 424 | """ 425 | Attach a mock as an attribute of this one, replacing its name and 426 | parent. Calls to the attached mock will be recorded in the 427 | `method_calls` and `mock_calls` attributes of this one.""" 428 | mock._mock_parent = None 429 | mock._mock_new_parent = None 430 | mock._mock_name = '' 431 | mock._mock_new_name = None 432 | 433 | setattr(self, attribute, mock) 434 | 435 | 436 | def mock_add_spec(self, spec, spec_set=False): 437 | """Add a spec to a mock. `spec` can either be an object or a 438 | list of strings. Only attributes on the `spec` can be fetched as 439 | attributes from the mock. 
440 | 441 | If `spec_set` is True then only attributes on the spec can be set.""" 442 | self._mock_add_spec(spec, spec_set) 443 | 444 | 445 | def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, 446 | _eat_self=False): 447 | _spec_class = None 448 | _spec_signature = None 449 | 450 | if spec is not None and not _is_list(spec): 451 | if isinstance(spec, type): 452 | _spec_class = spec 453 | else: 454 | _spec_class = _get_class(spec) 455 | res = _get_signature_object(spec, 456 | _spec_as_instance, _eat_self) 457 | _spec_signature = res and res[1] 458 | 459 | spec = dir(spec) 460 | 461 | __dict__ = self.__dict__ 462 | __dict__['_spec_class'] = _spec_class 463 | __dict__['_spec_set'] = spec_set 464 | __dict__['_spec_signature'] = _spec_signature 465 | __dict__['_mock_methods'] = spec 466 | 467 | 468 | def __get_return_value(self): 469 | ret = self._mock_return_value 470 | if self._mock_delegate is not None: 471 | ret = self._mock_delegate.return_value 472 | 473 | if ret is DEFAULT: 474 | ret = self._get_child_mock( 475 | _new_parent=self, _new_name='()' 476 | ) 477 | self.return_value = ret 478 | return ret 479 | 480 | 481 | def __set_return_value(self, value): 482 | if self._mock_delegate is not None: 483 | self._mock_delegate.return_value = value 484 | else: 485 | self._mock_return_value = value 486 | _check_and_set_parent(self, value, None, '()') 487 | 488 | __return_value_doc = "The value to be returned when the mock is called." 
489 | return_value = property(__get_return_value, __set_return_value, 490 | __return_value_doc) 491 | 492 | 493 | @property 494 | def __class__(self): 495 | if self._spec_class is None: 496 | return type(self) 497 | return self._spec_class 498 | 499 | called = _delegating_property('called') 500 | call_count = _delegating_property('call_count') 501 | call_args = _delegating_property('call_args') 502 | call_args_list = _delegating_property('call_args_list') 503 | mock_calls = _delegating_property('mock_calls') 504 | 505 | 506 | def __get_side_effect(self): 507 | delegated = self._mock_delegate 508 | if delegated is None: 509 | return self._mock_side_effect 510 | sf = delegated.side_effect 511 | if (sf is not None and not callable(sf) 512 | and not isinstance(sf, _MockIter) and not _is_exception(sf)): 513 | sf = _MockIter(sf) 514 | delegated.side_effect = sf 515 | return sf 516 | 517 | def __set_side_effect(self, value): 518 | value = _try_iter(value) 519 | delegated = self._mock_delegate 520 | if delegated is None: 521 | self._mock_side_effect = value 522 | else: 523 | delegated.side_effect = value 524 | 525 | side_effect = property(__get_side_effect, __set_side_effect) 526 | 527 | 528 | def reset_mock(self, visited=None): 529 | "Restore the mock object to its initial state." 530 | if visited is None: 531 | visited = [] 532 | if id(self) in visited: 533 | return 534 | visited.append(id(self)) 535 | 536 | self.called = False 537 | self.call_args = None 538 | self.call_count = 0 539 | self.mock_calls = _CallList() 540 | self.call_args_list = _CallList() 541 | self.method_calls = _CallList() 542 | 543 | for child in self._mock_children.values(): 544 | if isinstance(child, _SpecState): 545 | continue 546 | child.reset_mock(visited) 547 | 548 | ret = self._mock_return_value 549 | if _is_instance_mock(ret) and ret is not self: 550 | ret.reset_mock(visited) 551 | 552 | 553 | def configure_mock(self, **kwargs): 554 | """Set attributes on the mock through keyword arguments. 
555 | 556 | Attributes plus return values and side effects can be set on child 557 | mocks using standard dot notation and unpacking a dictionary in the 558 | method call: 559 | 560 | >>> attrs = {'method.return_value': 3, 'other.side_effect': KeyError} 561 | >>> mock.configure_mock(**attrs)""" 562 | for arg, val in sorted(kwargs.items(), 563 | # we sort on the number of dots so that 564 | # attributes are set before we set attributes on 565 | # attributes 566 | key=lambda entry: entry[0].count('.')): 567 | args = arg.split('.') 568 | final = args.pop() 569 | obj = self 570 | for entry in args: 571 | obj = getattr(obj, entry) 572 | setattr(obj, final, val) 573 | 574 | 575 | def __getattr__(self, name): 576 | if name in {'_mock_methods', '_mock_unsafe'}: 577 | raise AttributeError(name) 578 | elif self._mock_methods is not None: 579 | if name not in self._mock_methods or name in _all_magics: 580 | raise AttributeError("Mock object has no attribute %r" % name) 581 | elif _is_magic(name): 582 | raise AttributeError(name) 583 | if not self._mock_unsafe: 584 | if name.startswith(('assert', 'assret')): 585 | raise AttributeError(name) 586 | 587 | result = self._mock_children.get(name) 588 | if result is _deleted: 589 | raise AttributeError(name) 590 | elif result is None: 591 | wraps = None 592 | if self._mock_wraps is not None: 593 | # XXXX should we get the attribute without triggering code 594 | # execution? 
595 | wraps = getattr(self._mock_wraps, name) 596 | 597 | result = self._get_child_mock( 598 | parent=self, name=name, wraps=wraps, _new_name=name, 599 | _new_parent=self 600 | ) 601 | self._mock_children[name] = result 602 | 603 | elif isinstance(result, _SpecState): 604 | result = create_autospec( 605 | result.spec, result.spec_set, result.instance, 606 | result.parent, result.name 607 | ) 608 | self._mock_children[name] = result 609 | 610 | return result 611 | 612 | 613 | def __repr__(self): 614 | _name_list = [self._mock_new_name] 615 | _parent = self._mock_new_parent 616 | last = self 617 | 618 | dot = '.' 619 | if _name_list == ['()']: 620 | dot = '' 621 | seen = set() 622 | while _parent is not None: 623 | last = _parent 624 | 625 | _name_list.append(_parent._mock_new_name + dot) 626 | dot = '.' 627 | if _parent._mock_new_name == '()': 628 | dot = '' 629 | 630 | _parent = _parent._mock_new_parent 631 | 632 | # use ids here so as not to call __hash__ on the mocks 633 | if id(_parent) in seen: 634 | break 635 | seen.add(id(_parent)) 636 | 637 | _name_list = list(reversed(_name_list)) 638 | _first = last._mock_name or 'mock' 639 | if len(_name_list) > 1: 640 | if _name_list[1] not in ('()', '().'): 641 | _first += '.' 
642 | _name_list[0] = _first 643 | name = ''.join(_name_list) 644 | 645 | name_string = '' 646 | if name not in ('mock', 'mock.'): 647 | name_string = ' name=%r' % name 648 | 649 | spec_string = '' 650 | if self._spec_class is not None: 651 | spec_string = ' spec=%r' 652 | if self._spec_set: 653 | spec_string = ' spec_set=%r' 654 | spec_string = spec_string % self._spec_class.__name__ 655 | return "<%s%s%s id='%s'>" % ( 656 | type(self).__name__, 657 | name_string, 658 | spec_string, 659 | id(self) 660 | ) 661 | 662 | 663 | def __dir__(self): 664 | """Filter the output of `dir(mock)` to only useful members.""" 665 | if not FILTER_DIR: 666 | return object.__dir__(self) 667 | 668 | extras = self._mock_methods or [] 669 | from_type = dir(type(self)) 670 | from_dict = list(self.__dict__) 671 | 672 | from_type = [e for e in from_type if not e.startswith('_')] 673 | from_dict = [e for e in from_dict if not e.startswith('_') or 674 | _is_magic(e)] 675 | return sorted(set(extras + from_type + from_dict + 676 | list(self._mock_children))) 677 | 678 | 679 | def __setattr__(self, name, value): 680 | if name in _allowed_names: 681 | # property setters go through here 682 | return object.__setattr__(self, name, value) 683 | elif (self._spec_set and self._mock_methods is not None and 684 | name not in self._mock_methods and 685 | name not in self.__dict__): 686 | raise AttributeError("Mock object has no attribute '%s'" % name) 687 | elif name in _unsupported_magics: 688 | msg = 'Attempting to set unsupported magic method %r.' 
% name 689 | raise AttributeError(msg) 690 | elif name in _all_magics: 691 | if self._mock_methods is not None and name not in self._mock_methods: 692 | raise AttributeError("Mock object has no attribute '%s'" % name) 693 | 694 | if not _is_instance_mock(value): 695 | setattr(type(self), name, _get_method(name, value)) 696 | original = value 697 | value = lambda *args, **kw: original(self, *args, **kw) 698 | else: 699 | # only set _new_name and not name so that mock_calls is tracked 700 | # but not method calls 701 | _check_and_set_parent(self, value, None, name) 702 | setattr(type(self), name, value) 703 | self._mock_children[name] = value 704 | elif name == '__class__': 705 | self._spec_class = value 706 | return 707 | else: 708 | if _check_and_set_parent(self, value, name, name): 709 | self._mock_children[name] = value 710 | return object.__setattr__(self, name, value) 711 | 712 | 713 | def __delattr__(self, name): 714 | if name in _all_magics and name in type(self).__dict__: 715 | delattr(type(self), name) 716 | if name not in self.__dict__: 717 | # for magic methods that are still MagicProxy objects and 718 | # not set on the instance itself 719 | return 720 | 721 | if name in self.__dict__: 722 | object.__delattr__(self, name) 723 | 724 | obj = self._mock_children.get(name, _missing) 725 | if obj is _deleted: 726 | raise AttributeError(name) 727 | if obj is not _missing: 728 | del self._mock_children[name] 729 | self._mock_children[name] = _deleted 730 | 731 | 732 | def _format_mock_call_signature(self, args, kwargs): 733 | name = self._mock_name or 'mock' 734 | return _format_call_signature(name, args, kwargs) 735 | 736 | 737 | def _format_mock_failure_message(self, args, kwargs): 738 | message = 'Expected call: %s\nActual call: %s' 739 | expected_string = self._format_mock_call_signature(args, kwargs) 740 | call_args = self.call_args 741 | if len(call_args) == 3: 742 | call_args = call_args[1:] 743 | actual_string = 
self._format_mock_call_signature(*call_args) 744 | return message % (expected_string, actual_string) 745 | 746 | 747 | def _call_matcher(self, _call): 748 | """ 749 | Given a call (or simply an (args, kwargs) tuple), return a 750 | comparison key suitable for matching with other calls. 751 | This is a best effort method which relies on the spec's signature, 752 | if available, or falls back on the arguments themselves. 753 | """ 754 | sig = self._spec_signature 755 | if sig is not None: 756 | if len(_call) == 2: 757 | name = '' 758 | args, kwargs = _call 759 | else: 760 | name, args, kwargs = _call 761 | try: 762 | return name, sig.bind(*args, **kwargs) 763 | except TypeError as e: 764 | return e.with_traceback(None) 765 | else: 766 | return _call 767 | 768 | def assert_not_called(_mock_self): 769 | """assert that the mock was never called. 770 | """ 771 | self = _mock_self 772 | if self.call_count != 0: 773 | msg = ("Expected '%s' to not have been called. Called %s times." % 774 | (self._mock_name or 'mock', self.call_count)) 775 | raise AssertionError(msg) 776 | 777 | def assert_called_with(_mock_self, *args, **kwargs): 778 | """assert that the mock was called with the specified arguments. 
779 | 780 | Raises an AssertionError if the args and keyword args passed in are 781 | different to the last call to the mock.""" 782 | self = _mock_self 783 | if self.call_args is None: 784 | expected = self._format_mock_call_signature(args, kwargs) 785 | raise AssertionError('Expected call: %s\nNot called' % (expected,)) 786 | 787 | def _error_message(): 788 | msg = self._format_mock_failure_message(args, kwargs) 789 | return msg 790 | expected = self._call_matcher((args, kwargs)) 791 | actual = self._call_matcher(self.call_args) 792 | if expected != actual: 793 | cause = expected if isinstance(expected, Exception) else None 794 | raise AssertionError(_error_message()) from cause 795 | 796 | 797 | def assert_called_once_with(_mock_self, *args, **kwargs): 798 | """assert that the mock was called exactly once and that that call was 799 | with the specified arguments.""" 800 | self = _mock_self 801 | if not self.call_count == 1: 802 | msg = ("Expected '%s' to be called once. Called %s times." % 803 | (self._mock_name or 'mock', self.call_count)) 804 | raise AssertionError(msg) 805 | return self.assert_called_with(*args, **kwargs) 806 | 807 | 808 | def assert_has_calls(self, calls, any_order=False): 809 | """assert the mock has been called with the specified calls. 810 | The `mock_calls` list is checked for the calls. 811 | 812 | If `any_order` is False (the default) then the calls must be 813 | sequential. There can be extra calls before or after the 814 | specified calls. 
815 | 816 | If `any_order` is True then the calls can be in any order, but 817 | they must all appear in `mock_calls`.""" 818 | expected = [self._call_matcher(c) for c in calls] 819 | cause = expected if isinstance(expected, Exception) else None 820 | all_calls = _CallList(self._call_matcher(c) for c in self.mock_calls) 821 | if not any_order: 822 | if expected not in all_calls: 823 | raise AssertionError( 824 | 'Calls not found.\nExpected: %r\n' 825 | 'Actual: %r' % (calls, self.mock_calls) 826 | ) from cause 827 | return 828 | 829 | all_calls = list(all_calls) 830 | 831 | not_found = [] 832 | for kall in expected: 833 | try: 834 | all_calls.remove(kall) 835 | except ValueError: 836 | not_found.append(kall) 837 | if not_found: 838 | raise AssertionError( 839 | '%r not all found in call list' % (tuple(not_found),) 840 | ) from cause 841 | 842 | 843 | def assert_any_call(self, *args, **kwargs): 844 | """assert the mock has been called with the specified arguments. 845 | 846 | The assert passes if the mock has *ever* been called, unlike 847 | `assert_called_with` and `assert_called_once_with` that only pass if 848 | the call is the most recent one.""" 849 | expected = self._call_matcher((args, kwargs)) 850 | actual = [self._call_matcher(c) for c in self.call_args_list] 851 | if expected not in actual: 852 | cause = expected if isinstance(expected, Exception) else None 853 | expected_string = self._format_mock_call_signature(args, kwargs) 854 | raise AssertionError( 855 | '%s call not found' % expected_string 856 | ) from cause 857 | 858 | 859 | def _get_child_mock(self, **kw): 860 | """Create the child mocks for attributes and return value. 861 | By default child mocks will be the same type as the parent. 862 | Subclasses of Mock may want to override this to customize the way 863 | child mocks are made. 
864 | 865 | For non-callable mocks the callable variant will be used (rather than 866 | any custom subclass).""" 867 | _type = type(self) 868 | if not issubclass(_type, CallableMixin): 869 | if issubclass(_type, NonCallableMagicMock): 870 | klass = MagicMock 871 | elif issubclass(_type, NonCallableMock) : 872 | klass = Mock 873 | else: 874 | klass = _type.__mro__[1] 875 | return klass(**kw) 876 | 877 | 878 | 879 | def _try_iter(obj): 880 | if obj is None: 881 | return obj 882 | if _is_exception(obj): 883 | return obj 884 | if _callable(obj): 885 | return obj 886 | try: 887 | return iter(obj) 888 | except TypeError: 889 | # XXXX backwards compatibility 890 | # but this will blow up on first call - so maybe we should fail early? 891 | return obj 892 | 893 | 894 | 895 | class CallableMixin(Base): 896 | 897 | def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, 898 | wraps=None, name=None, spec_set=None, parent=None, 899 | _spec_state=None, _new_name='', _new_parent=None, **kwargs): 900 | self.__dict__['_mock_return_value'] = return_value 901 | 902 | _safe_super(CallableMixin, self).__init__( 903 | spec, wraps, name, spec_set, parent, 904 | _spec_state, _new_name, _new_parent, **kwargs 905 | ) 906 | 907 | self.side_effect = side_effect 908 | 909 | 910 | def _mock_check_sig(self, *args, **kwargs): 911 | # stub method that can be replaced with one with a specific signature 912 | pass 913 | 914 | 915 | def __call__(_mock_self, *args, **kwargs): 916 | # can't use self in-case a function / method we are mocking uses self 917 | # in the signature 918 | _mock_self._mock_check_sig(*args, **kwargs) 919 | return _mock_self._mock_call(*args, **kwargs) 920 | 921 | 922 | def _mock_call(_mock_self, *args, **kwargs): 923 | self = _mock_self 924 | self.called = True 925 | self.call_count += 1 926 | _new_name = self._mock_new_name 927 | _new_parent = self._mock_new_parent 928 | 929 | _call = _Call((args, kwargs), two=True) 930 | self.call_args = _call 931 | 
self.call_args_list.append(_call) 932 | self.mock_calls.append(_Call(('', args, kwargs))) 933 | 934 | seen = set() 935 | skip_next_dot = _new_name == '()' 936 | do_method_calls = self._mock_parent is not None 937 | name = self._mock_name 938 | while _new_parent is not None: 939 | this_mock_call = _Call((_new_name, args, kwargs)) 940 | if _new_parent._mock_new_name: 941 | dot = '.' 942 | if skip_next_dot: 943 | dot = '' 944 | 945 | skip_next_dot = False 946 | if _new_parent._mock_new_name == '()': 947 | skip_next_dot = True 948 | 949 | _new_name = _new_parent._mock_new_name + dot + _new_name 950 | 951 | if do_method_calls: 952 | if _new_name == name: 953 | this_method_call = this_mock_call 954 | else: 955 | this_method_call = _Call((name, args, kwargs)) 956 | _new_parent.method_calls.append(this_method_call) 957 | 958 | do_method_calls = _new_parent._mock_parent is not None 959 | if do_method_calls: 960 | name = _new_parent._mock_name + '.' + name 961 | 962 | _new_parent.mock_calls.append(this_mock_call) 963 | _new_parent = _new_parent._mock_new_parent 964 | 965 | # use ids here so as not to call __hash__ on the mocks 966 | _new_parent_id = id(_new_parent) 967 | if _new_parent_id in seen: 968 | break 969 | seen.add(_new_parent_id) 970 | 971 | ret_val = DEFAULT 972 | effect = self.side_effect 973 | if effect is not None: 974 | if _is_exception(effect): 975 | raise effect 976 | 977 | if not _callable(effect): 978 | result = next(effect) 979 | if _is_exception(result): 980 | raise result 981 | if result is DEFAULT: 982 | result = self.return_value 983 | return result 984 | 985 | ret_val = effect(*args, **kwargs) 986 | 987 | if (self._mock_wraps is not None and 988 | self._mock_return_value is DEFAULT): 989 | return self._mock_wraps(*args, **kwargs) 990 | if ret_val is DEFAULT: 991 | ret_val = self.return_value 992 | return ret_val 993 | 994 | 995 | 996 | class Mock(CallableMixin, NonCallableMock): 997 | """ 998 | Create a new `Mock` object. 
`Mock` takes several optional arguments 999 | that specify the behaviour of the Mock object: 1000 | 1001 | * `spec`: This can be either a list of strings or an existing object (a 1002 | class or instance) that acts as the specification for the mock object. If 1003 | you pass in an object then a list of strings is formed by calling dir on 1004 | the object (excluding unsupported magic attributes and methods). Accessing 1005 | any attribute not in this list will raise an `AttributeError`. 1006 | 1007 | If `spec` is an object (rather than a list of strings) then 1008 | `mock.__class__` returns the class of the spec object. This allows mocks 1009 | to pass `isinstance` tests. 1010 | 1011 | * `spec_set`: A stricter variant of `spec`. If used, attempting to *set* 1012 | or get an attribute on the mock that isn't on the object passed as 1013 | `spec_set` will raise an `AttributeError`. 1014 | 1015 | * `side_effect`: A function to be called whenever the Mock is called. See 1016 | the `side_effect` attribute. Useful for raising exceptions or 1017 | dynamically changing return values. The function is called with the same 1018 | arguments as the mock, and unless it returns `DEFAULT`, the return 1019 | value of this function is used as the return value. 1020 | 1021 | If `side_effect` is an iterable then each call to the mock will return 1022 | the next value from the iterable. If any of the members of the iterable 1023 | are exceptions they will be raised instead of returned. 1024 | 1025 | * `return_value`: The value returned when the mock is called. By default 1026 | this is a new Mock (created on first access). See the 1027 | `return_value` attribute. 1028 | 1029 | * `wraps`: Item for the mock object to wrap. If `wraps` is not None then 1030 | calling the Mock will pass the call through to the wrapped object 1031 | (returning the real result). 
Attribute access on the mock will return a 1032 | Mock object that wraps the corresponding attribute of the wrapped object 1033 | (so attempting to access an attribute that doesn't exist will raise an 1034 | `AttributeError`). 1035 | 1036 | If the mock has an explicit `return_value` set then calls are not passed 1037 | to the wrapped object and the `return_value` is returned instead. 1038 | 1039 | * `name`: If the mock has a name then it will be used in the repr of the 1040 | mock. This can be useful for debugging. The name is propagated to child 1041 | mocks. 1042 | 1043 | Mocks can also be called with arbitrary keyword arguments. These will be 1044 | used to set attributes on the mock after it is created. 1045 | """ 1046 | 1047 | 1048 | 1049 | def _dot_lookup(thing, comp, import_path): 1050 | try: 1051 | return getattr(thing, comp) 1052 | except AttributeError: 1053 | __import__(import_path) 1054 | return getattr(thing, comp) 1055 | 1056 | 1057 | def _importer(target): 1058 | components = target.split('.') 1059 | import_path = components.pop(0) 1060 | thing = __import__(import_path) 1061 | 1062 | for comp in components: 1063 | import_path += ".%s" % comp 1064 | thing = _dot_lookup(thing, comp, import_path) 1065 | return thing 1066 | 1067 | 1068 | def _is_started(patcher): 1069 | # XXXX horrible 1070 | return hasattr(patcher, 'is_local') 1071 | 1072 | 1073 | class _patch(object): 1074 | 1075 | attribute_name = None 1076 | _active_patches = [] 1077 | 1078 | def __init__( 1079 | self, getter, attribute, new, spec, create, 1080 | spec_set, autospec, new_callable, kwargs 1081 | ): 1082 | if new_callable is not None: 1083 | if new is not DEFAULT: 1084 | raise ValueError( 1085 | "Cannot use 'new' and 'new_callable' together" 1086 | ) 1087 | if autospec is not None: 1088 | raise ValueError( 1089 | "Cannot use 'autospec' and 'new_callable' together" 1090 | ) 1091 | 1092 | self.getter = getter 1093 | self.attribute = attribute 1094 | self.new = new 1095 | self.new_callable 
= new_callable 1096 | self.spec = spec 1097 | self.create = create 1098 | self.has_local = False 1099 | self.spec_set = spec_set 1100 | self.autospec = autospec 1101 | self.kwargs = kwargs 1102 | self.additional_patchers = [] 1103 | 1104 | 1105 | def copy(self): 1106 | patcher = _patch( 1107 | self.getter, self.attribute, self.new, self.spec, 1108 | self.create, self.spec_set, 1109 | self.autospec, self.new_callable, self.kwargs 1110 | ) 1111 | patcher.attribute_name = self.attribute_name 1112 | patcher.additional_patchers = [ 1113 | p.copy() for p in self.additional_patchers 1114 | ] 1115 | return patcher 1116 | 1117 | 1118 | def __call__(self, func): 1119 | if isinstance(func, type): 1120 | return self.decorate_class(func) 1121 | return self.decorate_callable(func) 1122 | 1123 | 1124 | def decorate_class(self, klass): 1125 | for attr in dir(klass): 1126 | if not attr.startswith(patch.TEST_PREFIX): 1127 | continue 1128 | 1129 | attr_value = getattr(klass, attr) 1130 | if not hasattr(attr_value, "__call__"): 1131 | continue 1132 | 1133 | patcher = self.copy() 1134 | setattr(klass, attr, patcher(attr_value)) 1135 | return klass 1136 | 1137 | 1138 | def decorate_callable(self, func): 1139 | if hasattr(func, 'patchings'): 1140 | func.patchings.append(self) 1141 | return func 1142 | 1143 | @wraps(func) 1144 | def patched(*args, **keywargs): 1145 | extra_args = [] 1146 | entered_patchers = [] 1147 | 1148 | exc_info = tuple() 1149 | try: 1150 | for patching in patched.patchings: 1151 | arg = patching.__enter__() 1152 | entered_patchers.append(patching) 1153 | if patching.attribute_name is not None: 1154 | keywargs.update(arg) 1155 | elif patching.new is DEFAULT: 1156 | extra_args.append(arg) 1157 | 1158 | args += tuple(extra_args) 1159 | return func(*args, **keywargs) 1160 | except: 1161 | if (patching not in entered_patchers and 1162 | _is_started(patching)): 1163 | # the patcher may have been started, but an exception 1164 | # raised whilst entering one of its 
additional_patchers 1165 | entered_patchers.append(patching) 1166 | # Pass the exception to __exit__ 1167 | exc_info = sys.exc_info() 1168 | # re-raise the exception 1169 | raise 1170 | finally: 1171 | for patching in reversed(entered_patchers): 1172 | patching.__exit__(*exc_info) 1173 | 1174 | patched.patchings = [self] 1175 | return patched 1176 | 1177 | 1178 | def get_original(self): 1179 | target = self.getter() 1180 | name = self.attribute 1181 | 1182 | original = DEFAULT 1183 | local = False 1184 | 1185 | try: 1186 | original = target.__dict__[name] 1187 | except (AttributeError, KeyError): 1188 | original = getattr(target, name, DEFAULT) 1189 | else: 1190 | local = True 1191 | 1192 | if name in _builtins and isinstance(target, ModuleType): 1193 | self.create = True 1194 | 1195 | if not self.create and original is DEFAULT: 1196 | raise AttributeError( 1197 | "%s does not have the attribute %r" % (target, name) 1198 | ) 1199 | return original, local 1200 | 1201 | 1202 | def __enter__(self): 1203 | """Perform the patch.""" 1204 | new, spec, spec_set = self.new, self.spec, self.spec_set 1205 | autospec, kwargs = self.autospec, self.kwargs 1206 | new_callable = self.new_callable 1207 | self.target = self.getter() 1208 | 1209 | # normalise False to None 1210 | if spec is False: 1211 | spec = None 1212 | if spec_set is False: 1213 | spec_set = None 1214 | if autospec is False: 1215 | autospec = None 1216 | 1217 | if spec is not None and autospec is not None: 1218 | raise TypeError("Can't specify spec and autospec") 1219 | if ((spec is not None or autospec is not None) and 1220 | spec_set not in (True, None)): 1221 | raise TypeError("Can't provide explicit spec_set *and* spec or autospec") 1222 | 1223 | original, local = self.get_original() 1224 | 1225 | if new is DEFAULT and autospec is None: 1226 | inherit = False 1227 | if spec is True: 1228 | # set spec to the object we are replacing 1229 | spec = original 1230 | if spec_set is True: 1231 | spec_set = original 
1232 | spec = None 1233 | elif spec is not None: 1234 | if spec_set is True: 1235 | spec_set = spec 1236 | spec = None 1237 | elif spec_set is True: 1238 | spec_set = original 1239 | 1240 | if spec is not None or spec_set is not None: 1241 | if original is DEFAULT: 1242 | raise TypeError("Can't use 'spec' with create=True") 1243 | if isinstance(original, type): 1244 | # If we're patching out a class and there is a spec 1245 | inherit = True 1246 | 1247 | Klass = MagicMock 1248 | _kwargs = {} 1249 | if new_callable is not None: 1250 | Klass = new_callable 1251 | elif spec is not None or spec_set is not None: 1252 | this_spec = spec 1253 | if spec_set is not None: 1254 | this_spec = spec_set 1255 | if _is_list(this_spec): 1256 | not_callable = '__call__' not in this_spec 1257 | else: 1258 | not_callable = not callable(this_spec) 1259 | if not_callable: 1260 | Klass = NonCallableMagicMock 1261 | 1262 | if spec is not None: 1263 | _kwargs['spec'] = spec 1264 | if spec_set is not None: 1265 | _kwargs['spec_set'] = spec_set 1266 | 1267 | # add a name to mocks 1268 | if (isinstance(Klass, type) and 1269 | issubclass(Klass, NonCallableMock) and self.attribute): 1270 | _kwargs['name'] = self.attribute 1271 | 1272 | _kwargs.update(kwargs) 1273 | new = Klass(**_kwargs) 1274 | 1275 | if inherit and _is_instance_mock(new): 1276 | # we can only tell if the instance should be callable if the 1277 | # spec is not a list 1278 | this_spec = spec 1279 | if spec_set is not None: 1280 | this_spec = spec_set 1281 | if (not _is_list(this_spec) and not 1282 | _instance_callable(this_spec)): 1283 | Klass = NonCallableMagicMock 1284 | 1285 | _kwargs.pop('name') 1286 | new.return_value = Klass(_new_parent=new, _new_name='()', 1287 | **_kwargs) 1288 | elif autospec is not None: 1289 | # spec is ignored, new *must* be default, spec_set is treated 1290 | # as a boolean. Should we check spec is not None and that spec_set 1291 | # is a bool? 
1292 | if new is not DEFAULT: 1293 | raise TypeError( 1294 | "autospec creates the mock for you. Can't specify " 1295 | "autospec and new." 1296 | ) 1297 | if original is DEFAULT: 1298 | raise TypeError("Can't use 'autospec' with create=True") 1299 | spec_set = bool(spec_set) 1300 | if autospec is True: 1301 | autospec = original 1302 | 1303 | new = create_autospec(autospec, spec_set=spec_set, 1304 | _name=self.attribute, **kwargs) 1305 | elif kwargs: 1306 | # can't set keyword args when we aren't creating the mock 1307 | # XXXX If new is a Mock we could call new.configure_mock(**kwargs) 1308 | raise TypeError("Can't pass kwargs to a mock we aren't creating") 1309 | 1310 | new_attr = new 1311 | 1312 | self.temp_original = original 1313 | self.is_local = local 1314 | setattr(self.target, self.attribute, new_attr) 1315 | if self.attribute_name is not None: 1316 | extra_args = {} 1317 | if self.new is DEFAULT: 1318 | extra_args[self.attribute_name] = new 1319 | for patching in self.additional_patchers: 1320 | arg = patching.__enter__() 1321 | if patching.new is DEFAULT: 1322 | extra_args.update(arg) 1323 | return extra_args 1324 | 1325 | return new 1326 | 1327 | 1328 | def __exit__(self, *exc_info): 1329 | """Undo the patch.""" 1330 | if not _is_started(self): 1331 | raise RuntimeError('stop called on unstarted patcher') 1332 | 1333 | if self.is_local and self.temp_original is not DEFAULT: 1334 | setattr(self.target, self.attribute, self.temp_original) 1335 | else: 1336 | delattr(self.target, self.attribute) 1337 | if not self.create and (not hasattr(self.target, self.attribute) or 1338 | self.attribute in ('__doc__', '__module__', 1339 | '__defaults__', '__annotations__', 1340 | '__kwdefaults__')): 1341 | # needed for proxy objects like django settings 1342 | setattr(self.target, self.attribute, self.temp_original) 1343 | 1344 | del self.temp_original 1345 | del self.is_local 1346 | del self.target 1347 | for patcher in reversed(self.additional_patchers): 1348 | if 
_is_started(patcher): 1349 | patcher.__exit__(*exc_info) 1350 | 1351 | 1352 | def start(self): 1353 | """Activate a patch, returning any created mock.""" 1354 | result = self.__enter__() 1355 | self._active_patches.append(self) 1356 | return result 1357 | 1358 | 1359 | def stop(self): 1360 | """Stop an active patch.""" 1361 | try: 1362 | self._active_patches.remove(self) 1363 | except ValueError: 1364 | # If the patch hasn't been started this will fail 1365 | pass 1366 | 1367 | return self.__exit__() 1368 | 1369 | 1370 | 1371 | def _get_target(target): 1372 | try: 1373 | target, attribute = target.rsplit('.', 1) 1374 | except (TypeError, ValueError): 1375 | raise TypeError("Need a valid target to patch. You supplied: %r" % 1376 | (target,)) 1377 | getter = lambda: _importer(target) 1378 | return getter, attribute 1379 | 1380 | 1381 | def _patch_object( 1382 | target, attribute, new=DEFAULT, spec=None, 1383 | create=False, spec_set=None, autospec=None, 1384 | new_callable=None, **kwargs 1385 | ): 1386 | """ 1387 | patch the named member (`attribute`) on an object (`target`) with a mock 1388 | object. 1389 | 1390 | `patch.object` can be used as a decorator, class decorator or a context 1391 | manager. Arguments `new`, `spec`, `create`, `spec_set`, 1392 | `autospec` and `new_callable` have the same meaning as for `patch`. Like 1393 | `patch`, `patch.object` takes arbitrary keyword arguments for configuring 1394 | the mock object it creates. 1395 | 1396 | When used as a class decorator `patch.object` honours `patch.TEST_PREFIX` 1397 | for choosing which methods to wrap. 1398 | """ 1399 | getter = lambda: target 1400 | return _patch( 1401 | getter, attribute, new, spec, create, 1402 | spec_set, autospec, new_callable, kwargs 1403 | ) 1404 | 1405 | 1406 | def _patch_multiple(target, spec=None, create=False, spec_set=None, 1407 | autospec=None, new_callable=None, **kwargs): 1408 | """Perform multiple patches in a single call. 
It takes the object to be 1409 | patched (either as an object or a string to fetch the object by importing) 1410 | and keyword arguments for the patches:: 1411 | 1412 | with patch.multiple(settings, FIRST_PATCH='one', SECOND_PATCH='two'): 1413 | ... 1414 | 1415 | Use `DEFAULT` as the value if you want `patch.multiple` to create 1416 | mocks for you. In this case the created mocks are passed into a decorated 1417 | function by keyword, and a dictionary is returned when `patch.multiple` is 1418 | used as a context manager. 1419 | 1420 | `patch.multiple` can be used as a decorator, class decorator or a context 1421 | manager. The arguments `spec`, `spec_set`, `create`, 1422 | `autospec` and `new_callable` have the same meaning as for `patch`. These 1423 | arguments will be applied to *all* patches done by `patch.multiple`. 1424 | 1425 | When used as a class decorator `patch.multiple` honours `patch.TEST_PREFIX` 1426 | for choosing which methods to wrap. 1427 | """ 1428 | if type(target) is str: 1429 | getter = lambda: _importer(target) 1430 | else: 1431 | getter = lambda: target 1432 | 1433 | if not kwargs: 1434 | raise ValueError( 1435 | 'Must supply at least one keyword argument with patch.multiple' 1436 | ) 1437 | # need to wrap in a list for python 3, where items is a view 1438 | items = list(kwargs.items()) 1439 | attribute, new = items[0] 1440 | patcher = _patch( 1441 | getter, attribute, new, spec, create, spec_set, 1442 | autospec, new_callable, {} 1443 | ) 1444 | patcher.attribute_name = attribute 1445 | for attribute, new in items[1:]: 1446 | this_patcher = _patch( 1447 | getter, attribute, new, spec, create, spec_set, 1448 | autospec, new_callable, {} 1449 | ) 1450 | this_patcher.attribute_name = attribute 1451 | patcher.additional_patchers.append(this_patcher) 1452 | return patcher 1453 | 1454 | 1455 | def patch( 1456 | target, new=DEFAULT, spec=None, create=False, 1457 | spec_set=None, autospec=None, new_callable=None, **kwargs 1458 | ): 1459 | """ 1460 | 
`patch` acts as a function decorator, class decorator or a context 1461 | manager. Inside the body of the function or with statement, the `target` 1462 | is patched with a `new` object. When the function/with statement exits 1463 | the patch is undone. 1464 | 1465 | If `new` is omitted, then the target is replaced with a 1466 | `MagicMock`. If `patch` is used as a decorator and `new` is 1467 | omitted, the created mock is passed in as an extra argument to the 1468 | decorated function. If `patch` is used as a context manager the created 1469 | mock is returned by the context manager. 1470 | 1471 | `target` should be a string in the form `'package.module.ClassName'`. The 1472 | `target` is imported and the specified object replaced with the `new` 1473 | object, so the `target` must be importable from the environment you are 1474 | calling `patch` from. The target is imported when the decorated function 1475 | is executed, not at decoration time. 1476 | 1477 | The `spec` and `spec_set` keyword arguments are passed to the `MagicMock` 1478 | if patch is creating one for you. 1479 | 1480 | In addition you can pass `spec=True` or `spec_set=True`, which causes 1481 | patch to pass in the object being mocked as the spec/spec_set object. 1482 | 1483 | `new_callable` allows you to specify a different class, or callable object, 1484 | that will be called to create the `new` object. By default `MagicMock` is 1485 | used. 1486 | 1487 | A more powerful form of `spec` is `autospec`. If you set `autospec=True` 1488 | then the mock will be created with a spec from the object being replaced. 1489 | All attributes of the mock will also have the spec of the corresponding 1490 | attribute of the object being replaced. Methods and functions being 1491 | mocked will have their arguments checked and will raise a `TypeError` if 1492 | they are called with the wrong signature. For mocks replacing a class, 1493 | their return value (the 'instance') will have the same spec as the class. 
1494 | 1495 | Instead of `autospec=True` you can pass `autospec=some_object` to use an 1496 | arbitrary object as the spec instead of the one being replaced. 1497 | 1498 | By default `patch` will fail to replace attributes that don't exist. If 1499 | you pass in `create=True`, and the attribute doesn't exist, patch will 1500 | create the attribute for you when the patched function is called, and 1501 | delete it again afterwards. This is useful for writing tests against 1502 | attributes that your production code creates at runtime. It is off by 1503 | default because it can be dangerous. With it switched on you can write 1504 | passing tests against APIs that don't actually exist! 1505 | 1506 | Patch can be used as a `TestCase` class decorator. It works by 1507 | decorating each test method in the class. This reduces the boilerplate 1508 | code when your test methods share a common patchings set. `patch` finds 1509 | tests by looking for method names that start with `patch.TEST_PREFIX`. 1510 | By default this is `test`, which matches the way `unittest` finds tests. 1511 | You can specify an alternative prefix by setting `patch.TEST_PREFIX`. 1512 | 1513 | Patch can be used as a context manager, with the with statement. Here the 1514 | patching applies to the indented block after the with statement. If you 1515 | use "as" then the patched object will be bound to the name after the 1516 | "as"; very useful if `patch` is creating a mock object for you. 1517 | 1518 | `patch` takes arbitrary keyword arguments. These will be passed to 1519 | the `Mock` (or `new_callable`) on construction. 1520 | 1521 | `patch.dict(...)`, `patch.multiple(...)` and `patch.object(...)` are 1522 | available for alternate use-cases. 
1523 | """ 1524 | getter, attribute = _get_target(target) 1525 | return _patch( 1526 | getter, attribute, new, spec, create, 1527 | spec_set, autospec, new_callable, kwargs 1528 | ) 1529 | 1530 | 1531 | class _patch_dict(object): 1532 | """ 1533 | Patch a dictionary, or dictionary like object, and restore the dictionary 1534 | to its original state after the test. 1535 | 1536 | `in_dict` can be a dictionary or a mapping like container. If it is a 1537 | mapping then it must at least support getting, setting and deleting items 1538 | plus iterating over keys. 1539 | 1540 | `in_dict` can also be a string specifying the name of the dictionary, which 1541 | will then be fetched by importing it. 1542 | 1543 | `values` can be a dictionary of values to set in the dictionary. `values` 1544 | can also be an iterable of `(key, value)` pairs. 1545 | 1546 | If `clear` is True then the dictionary will be cleared before the new 1547 | values are set. 1548 | 1549 | `patch.dict` can also be called with arbitrary keyword arguments to set 1550 | values in the dictionary:: 1551 | 1552 | with patch.dict('sys.modules', mymodule=Mock(), other_module=Mock()): 1553 | ... 1554 | 1555 | `patch.dict` can be used as a context manager, decorator or class 1556 | decorator. When used as a class decorator `patch.dict` honours 1557 | `patch.TEST_PREFIX` for choosing which methods to wrap. 1558 | """ 1559 | 1560 | def __init__(self, in_dict, values=(), clear=False, **kwargs): 1561 | if isinstance(in_dict, str): 1562 | in_dict = _importer(in_dict) 1563 | self.in_dict = in_dict 1564 | # support any argument supported by dict(...) 
constructor 1565 | self.values = dict(values) 1566 | self.values.update(kwargs) 1567 | self.clear = clear 1568 | self._original = None 1569 | 1570 | 1571 | def __call__(self, f): 1572 | if isinstance(f, type): 1573 | return self.decorate_class(f) 1574 | @wraps(f) 1575 | def _inner(*args, **kw): 1576 | self._patch_dict() 1577 | try: 1578 | return f(*args, **kw) 1579 | finally: 1580 | self._unpatch_dict() 1581 | 1582 | return _inner 1583 | 1584 | 1585 | def decorate_class(self, klass): 1586 | for attr in dir(klass): 1587 | attr_value = getattr(klass, attr) 1588 | if (attr.startswith(patch.TEST_PREFIX) and 1589 | hasattr(attr_value, "__call__")): 1590 | decorator = _patch_dict(self.in_dict, self.values, self.clear) 1591 | decorated = decorator(attr_value) 1592 | setattr(klass, attr, decorated) 1593 | return klass 1594 | 1595 | 1596 | def __enter__(self): 1597 | """Patch the dict.""" 1598 | self._patch_dict() 1599 | 1600 | 1601 | def _patch_dict(self): 1602 | values = self.values 1603 | in_dict = self.in_dict 1604 | clear = self.clear 1605 | 1606 | try: 1607 | original = in_dict.copy() 1608 | except AttributeError: 1609 | # dict like object with no copy method 1610 | # must support iteration over keys 1611 | original = {} 1612 | for key in in_dict: 1613 | original[key] = in_dict[key] 1614 | self._original = original 1615 | 1616 | if clear: 1617 | _clear_dict(in_dict) 1618 | 1619 | try: 1620 | in_dict.update(values) 1621 | except AttributeError: 1622 | # dict like object with no update method 1623 | for key in values: 1624 | in_dict[key] = values[key] 1625 | 1626 | 1627 | def _unpatch_dict(self): 1628 | in_dict = self.in_dict 1629 | original = self._original 1630 | 1631 | _clear_dict(in_dict) 1632 | 1633 | try: 1634 | in_dict.update(original) 1635 | except AttributeError: 1636 | for key in original: 1637 | in_dict[key] = original[key] 1638 | 1639 | 1640 | def __exit__(self, *args): 1641 | """Unpatch the dict.""" 1642 | self._unpatch_dict() 1643 | return False 1644 | 
1645 | start = __enter__ 1646 | stop = __exit__ 1647 | 1648 | 1649 | def _clear_dict(in_dict): 1650 | try: 1651 | in_dict.clear() 1652 | except AttributeError: 1653 | keys = list(in_dict) 1654 | for key in keys: 1655 | del in_dict[key] 1656 | 1657 | 1658 | def _patch_stopall(): 1659 | """Stop all active patches. LIFO to unroll nested patches.""" 1660 | for patch in reversed(_patch._active_patches): 1661 | patch.stop() 1662 | 1663 | 1664 | patch.object = _patch_object 1665 | patch.dict = _patch_dict 1666 | patch.multiple = _patch_multiple 1667 | patch.stopall = _patch_stopall 1668 | patch.TEST_PREFIX = 'test' 1669 | 1670 | magic_methods = ( 1671 | "lt le gt ge eq ne " 1672 | "getitem setitem delitem " 1673 | "len contains iter " 1674 | "hash str sizeof " 1675 | "enter exit " 1676 | # we added divmod and rdivmod here instead of numerics 1677 | # because there is no idivmod 1678 | "divmod rdivmod neg pos abs invert " 1679 | "complex int float index " 1680 | "trunc floor ceil " 1681 | "bool next " 1682 | ) 1683 | 1684 | numerics = ( 1685 | "add sub mul matmul div floordiv mod lshift rshift and xor or pow truediv" 1686 | ) 1687 | inplace = ' '.join('i%s' % n for n in numerics.split()) 1688 | right = ' '.join('r%s' % n for n in numerics.split()) 1689 | 1690 | # not including __prepare__, __instancecheck__, __subclasscheck__ 1691 | # (as they are metaclass methods) 1692 | # __del__ is not supported at all as it causes problems if it exists 1693 | 1694 | _non_defaults = { 1695 | '__get__', '__set__', '__delete__', '__reversed__', '__missing__', 1696 | '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', 1697 | '__getstate__', '__setstate__', '__getformat__', '__setformat__', 1698 | '__repr__', '__dir__', '__subclasses__', '__format__', 1699 | '__getnewargs_ex__', 1700 | } 1701 | 1702 | 1703 | def _get_method(name, func): 1704 | "Turns a callable object (like a mock) into a real function" 1705 | def method(self, *args, **kw): 1706 | return func(self, *args, 
**kw) 1707 | method.__name__ = name 1708 | return method 1709 | 1710 | 1711 | _magics = { 1712 | '__%s__' % method for method in 1713 | ' '.join([magic_methods, numerics, inplace, right]).split() 1714 | } 1715 | 1716 | _all_magics = _magics | _non_defaults 1717 | 1718 | _unsupported_magics = { 1719 | '__getattr__', '__setattr__', 1720 | '__init__', '__new__', '__prepare__' 1721 | '__instancecheck__', '__subclasscheck__', 1722 | '__del__' 1723 | } 1724 | 1725 | _calculate_return_value = { 1726 | '__hash__': lambda self: object.__hash__(self), 1727 | '__str__': lambda self: object.__str__(self), 1728 | '__sizeof__': lambda self: object.__sizeof__(self), 1729 | } 1730 | 1731 | _return_values = { 1732 | '__lt__': NotImplemented, 1733 | '__gt__': NotImplemented, 1734 | '__le__': NotImplemented, 1735 | '__ge__': NotImplemented, 1736 | '__int__': 1, 1737 | '__contains__': False, 1738 | '__len__': 0, 1739 | '__exit__': False, 1740 | '__complex__': 1j, 1741 | '__float__': 1.0, 1742 | '__bool__': True, 1743 | '__index__': 1, 1744 | } 1745 | 1746 | 1747 | def _get_eq(self): 1748 | def __eq__(other): 1749 | ret_val = self.__eq__._mock_return_value 1750 | if ret_val is not DEFAULT: 1751 | return ret_val 1752 | if self is other: 1753 | return True 1754 | return NotImplemented 1755 | return __eq__ 1756 | 1757 | def _get_ne(self): 1758 | def __ne__(other): 1759 | if self.__ne__._mock_return_value is not DEFAULT: 1760 | return DEFAULT 1761 | if self is other: 1762 | return False 1763 | return NotImplemented 1764 | return __ne__ 1765 | 1766 | def _get_iter(self): 1767 | def __iter__(): 1768 | ret_val = self.__iter__._mock_return_value 1769 | if ret_val is DEFAULT: 1770 | return iter([]) 1771 | # if ret_val was already an iterator, then calling iter on it should 1772 | # return the iterator unchanged 1773 | return iter(ret_val) 1774 | return __iter__ 1775 | 1776 | _side_effect_methods = { 1777 | '__eq__': _get_eq, 1778 | '__ne__': _get_ne, 1779 | '__iter__': _get_iter, 1780 | } 1781 
| 1782 | 1783 | 1784 | def _set_return_value(mock, method, name): 1785 | fixed = _return_values.get(name, DEFAULT) 1786 | if fixed is not DEFAULT: 1787 | method.return_value = fixed 1788 | return 1789 | 1790 | return_calulator = _calculate_return_value.get(name) 1791 | if return_calulator is not None: 1792 | try: 1793 | return_value = return_calulator(mock) 1794 | except AttributeError: 1795 | # XXXX why do we return AttributeError here? 1796 | # set it as a side_effect instead? 1797 | return_value = AttributeError(name) 1798 | method.return_value = return_value 1799 | return 1800 | 1801 | side_effector = _side_effect_methods.get(name) 1802 | if side_effector is not None: 1803 | method.side_effect = side_effector(mock) 1804 | 1805 | 1806 | 1807 | class MagicMixin(object): 1808 | def __init__(self, *args, **kw): 1809 | self._mock_set_magics() # make magic work for kwargs in init 1810 | _safe_super(MagicMixin, self).__init__(*args, **kw) 1811 | self._mock_set_magics() # fix magic broken by upper level init 1812 | 1813 | 1814 | def _mock_set_magics(self): 1815 | these_magics = _magics 1816 | 1817 | if getattr(self, "_mock_methods", None) is not None: 1818 | these_magics = _magics.intersection(self._mock_methods) 1819 | 1820 | remove_magics = set() 1821 | remove_magics = _magics - these_magics 1822 | 1823 | for entry in remove_magics: 1824 | if entry in type(self).__dict__: 1825 | # remove unneeded magic methods 1826 | delattr(self, entry) 1827 | 1828 | # don't overwrite existing attributes if called a second time 1829 | these_magics = these_magics - set(type(self).__dict__) 1830 | 1831 | _type = type(self) 1832 | for entry in these_magics: 1833 | setattr(_type, entry, MagicProxy(entry, self)) 1834 | 1835 | 1836 | 1837 | class NonCallableMagicMock(MagicMixin, NonCallableMock): 1838 | """A version of `MagicMock` that isn't callable.""" 1839 | def mock_add_spec(self, spec, spec_set=False): 1840 | """Add a spec to a mock. 
`spec` can either be an object or a 1841 | list of strings. Only attributes on the `spec` can be fetched as 1842 | attributes from the mock. 1843 | 1844 | If `spec_set` is True then only attributes on the spec can be set.""" 1845 | self._mock_add_spec(spec, spec_set) 1846 | self._mock_set_magics() 1847 | 1848 | 1849 | 1850 | class MagicMock(MagicMixin, Mock): 1851 | """ 1852 | MagicMock is a subclass of Mock with default implementations 1853 | of most of the magic methods. You can use MagicMock without having to 1854 | configure the magic methods yourself. 1855 | 1856 | If you use the `spec` or `spec_set` arguments then *only* magic 1857 | methods that exist in the spec will be created. 1858 | 1859 | Attributes and the return value of a `MagicMock` will also be `MagicMocks`. 1860 | """ 1861 | def mock_add_spec(self, spec, spec_set=False): 1862 | """Add a spec to a mock. `spec` can either be an object or a 1863 | list of strings. Only attributes on the `spec` can be fetched as 1864 | attributes from the mock. 1865 | 1866 | If `spec_set` is True then only attributes on the spec can be set.""" 1867 | self._mock_add_spec(spec, spec_set) 1868 | self._mock_set_magics() 1869 | 1870 | 1871 | 1872 | class MagicProxy(object): 1873 | def __init__(self, name, parent): 1874 | self.name = name 1875 | self.parent = parent 1876 | 1877 | def __call__(self, *args, **kwargs): 1878 | m = self.create_mock() 1879 | return m(*args, **kwargs) 1880 | 1881 | def create_mock(self): 1882 | entry = self.name 1883 | parent = self.parent 1884 | m = parent._get_child_mock(name=entry, _new_name=entry, 1885 | _new_parent=parent) 1886 | setattr(parent, entry, m) 1887 | _set_return_value(parent, m, entry) 1888 | return m 1889 | 1890 | def __get__(self, obj, _type=None): 1891 | return self.create_mock() 1892 | 1893 | 1894 | 1895 | class _ANY(object): 1896 | "A helper object that compares equal to everything." 
1897 | 1898 | def __eq__(self, other): 1899 | return True 1900 | 1901 | def __ne__(self, other): 1902 | return False 1903 | 1904 | def __repr__(self): 1905 | return '' 1906 | 1907 | ANY = _ANY() 1908 | 1909 | 1910 | 1911 | def _format_call_signature(name, args, kwargs): 1912 | message = '%s(%%s)' % name 1913 | formatted_args = '' 1914 | args_string = ', '.join([repr(arg) for arg in args]) 1915 | kwargs_string = ', '.join([ 1916 | '%s=%r' % (key, value) for key, value in sorted(kwargs.items()) 1917 | ]) 1918 | if args_string: 1919 | formatted_args = args_string 1920 | if kwargs_string: 1921 | if formatted_args: 1922 | formatted_args += ', ' 1923 | formatted_args += kwargs_string 1924 | 1925 | return message % formatted_args 1926 | 1927 | 1928 | 1929 | class _Call(tuple): 1930 | """ 1931 | A tuple for holding the results of a call to a mock, either in the form 1932 | `(args, kwargs)` or `(name, args, kwargs)`. 1933 | 1934 | If args or kwargs are empty then a call tuple will compare equal to 1935 | a tuple without those values. This makes comparisons less verbose:: 1936 | 1937 | _Call(('name', (), {})) == ('name',) 1938 | _Call(('name', (1,), {})) == ('name', (1,)) 1939 | _Call(((), {'a': 'b'})) == ({'a': 'b'},) 1940 | 1941 | The `_Call` object provides a useful shortcut for comparing with call:: 1942 | 1943 | _Call(((1, 2), {'a': 3})) == call(1, 2, a=3) 1944 | _Call(('foo', (1, 2), {'a': 3})) == call.foo(1, 2, a=3) 1945 | 1946 | If the _Call has no name then it will match any name. 
1947 | """ 1948 | def __new__(cls, value=(), name='', parent=None, two=False, 1949 | from_kall=True): 1950 | args = () 1951 | kwargs = {} 1952 | _len = len(value) 1953 | if _len == 3: 1954 | name, args, kwargs = value 1955 | elif _len == 2: 1956 | first, second = value 1957 | if isinstance(first, str): 1958 | name = first 1959 | if isinstance(second, tuple): 1960 | args = second 1961 | else: 1962 | kwargs = second 1963 | else: 1964 | args, kwargs = first, second 1965 | elif _len == 1: 1966 | value, = value 1967 | if isinstance(value, str): 1968 | name = value 1969 | elif isinstance(value, tuple): 1970 | args = value 1971 | else: 1972 | kwargs = value 1973 | 1974 | if two: 1975 | return tuple.__new__(cls, (args, kwargs)) 1976 | 1977 | return tuple.__new__(cls, (name, args, kwargs)) 1978 | 1979 | 1980 | def __init__(self, value=(), name=None, parent=None, two=False, 1981 | from_kall=True): 1982 | self.name = name 1983 | self.parent = parent 1984 | self.from_kall = from_kall 1985 | 1986 | 1987 | def __eq__(self, other): 1988 | if other is ANY: 1989 | return True 1990 | try: 1991 | len_other = len(other) 1992 | except TypeError: 1993 | return False 1994 | 1995 | self_name = '' 1996 | if len(self) == 2: 1997 | self_args, self_kwargs = self 1998 | else: 1999 | self_name, self_args, self_kwargs = self 2000 | 2001 | other_name = '' 2002 | if len_other == 0: 2003 | other_args, other_kwargs = (), {} 2004 | elif len_other == 3: 2005 | other_name, other_args, other_kwargs = other 2006 | elif len_other == 1: 2007 | value, = other 2008 | if isinstance(value, tuple): 2009 | other_args = value 2010 | other_kwargs = {} 2011 | elif isinstance(value, str): 2012 | other_name = value 2013 | other_args, other_kwargs = (), {} 2014 | else: 2015 | other_args = () 2016 | other_kwargs = value 2017 | elif len_other == 2: 2018 | # could be (name, args) or (name, kwargs) or (args, kwargs) 2019 | first, second = other 2020 | if isinstance(first, str): 2021 | other_name = first 2022 | if 
isinstance(second, tuple): 2023 | other_args, other_kwargs = second, {} 2024 | else: 2025 | other_args, other_kwargs = (), second 2026 | else: 2027 | other_args, other_kwargs = first, second 2028 | else: 2029 | return False 2030 | 2031 | if self_name and other_name != self_name: 2032 | return False 2033 | 2034 | # this order is important for ANY to work! 2035 | return (other_args, other_kwargs) == (self_args, self_kwargs) 2036 | 2037 | 2038 | __ne__ = object.__ne__ 2039 | 2040 | 2041 | def __call__(self, *args, **kwargs): 2042 | if self.name is None: 2043 | return _Call(('', args, kwargs), name='()') 2044 | 2045 | name = self.name + '()' 2046 | return _Call((self.name, args, kwargs), name=name, parent=self) 2047 | 2048 | 2049 | def __getattr__(self, attr): 2050 | if self.name is None: 2051 | return _Call(name=attr, from_kall=False) 2052 | name = '%s.%s' % (self.name, attr) 2053 | return _Call(name=name, parent=self, from_kall=False) 2054 | 2055 | 2056 | def count(self, *args, **kwargs): 2057 | return self.__getattr__('count')(*args, **kwargs) 2058 | 2059 | def index(self, *args, **kwargs): 2060 | return self.__getattr__('index')(*args, **kwargs) 2061 | 2062 | def __repr__(self): 2063 | if not self.from_kall: 2064 | name = self.name or 'call' 2065 | if name.startswith('()'): 2066 | name = 'call%s' % name 2067 | return name 2068 | 2069 | if len(self) == 2: 2070 | name = 'call' 2071 | args, kwargs = self 2072 | else: 2073 | name, args, kwargs = self 2074 | if not name: 2075 | name = 'call' 2076 | elif not name.startswith('()'): 2077 | name = 'call.%s' % name 2078 | else: 2079 | name = 'call%s' % name 2080 | return _format_call_signature(name, args, kwargs) 2081 | 2082 | 2083 | def call_list(self): 2084 | """For a call object that represents multiple calls, `call_list` 2085 | returns a list of all the intermediate calls as well as the 2086 | final call.""" 2087 | vals = [] 2088 | thing = self 2089 | while thing is not None: 2090 | if thing.from_kall: 2091 | 
vals.append(thing) 2092 | thing = thing.parent 2093 | return _CallList(reversed(vals)) 2094 | 2095 | 2096 | call = _Call(from_kall=False) 2097 | 2098 | 2099 | 2100 | def create_autospec(spec, spec_set=False, instance=False, _parent=None, 2101 | _name=None, **kwargs): 2102 | """Create a mock object using another object as a spec. Attributes on the 2103 | mock will use the corresponding attribute on the `spec` object as their 2104 | spec. 2105 | 2106 | Functions or methods being mocked will have their arguments checked 2107 | to check that they are called with the correct signature. 2108 | 2109 | If `spec_set` is True then attempting to set attributes that don't exist 2110 | on the spec object will raise an `AttributeError`. 2111 | 2112 | If a class is used as a spec then the return value of the mock (the 2113 | instance of the class) will have the same spec. You can use a class as the 2114 | spec for an instance object by passing `instance=True`. The returned mock 2115 | will only be callable if instances of the mock are callable. 
2116 | 2117 | `create_autospec` also takes arbitrary keyword arguments that are passed to 2118 | the constructor of the created mock.""" 2119 | if _is_list(spec): 2120 | # can't pass a list instance to the mock constructor as it will be 2121 | # interpreted as a list of strings 2122 | spec = type(spec) 2123 | 2124 | is_type = isinstance(spec, type) 2125 | 2126 | _kwargs = {'spec': spec} 2127 | if spec_set: 2128 | _kwargs = {'spec_set': spec} 2129 | elif spec is None: 2130 | # None we mock with a normal mock without a spec 2131 | _kwargs = {} 2132 | if _kwargs and instance: 2133 | _kwargs['_spec_as_instance'] = True 2134 | 2135 | _kwargs.update(kwargs) 2136 | 2137 | Klass = MagicMock 2138 | if inspect.isdatadescriptor(spec): 2139 | # descriptors don't have a spec 2140 | # because we don't know what type they return 2141 | _kwargs = {} 2142 | elif not _callable(spec): 2143 | Klass = NonCallableMagicMock 2144 | elif is_type and instance and not _instance_callable(spec): 2145 | Klass = NonCallableMagicMock 2146 | 2147 | _name = _kwargs.pop('name', _name) 2148 | 2149 | _new_name = _name 2150 | if _parent is None: 2151 | # for a top level object no _new_name should be set 2152 | _new_name = '' 2153 | 2154 | mock = Klass(parent=_parent, _new_parent=_parent, _new_name=_new_name, 2155 | name=_name, **_kwargs) 2156 | 2157 | if isinstance(spec, FunctionTypes): 2158 | # should only happen at the top level because we don't 2159 | # recurse for functions 2160 | mock = _set_signature(mock, spec) 2161 | else: 2162 | _check_signature(spec, mock, is_type, instance) 2163 | 2164 | if _parent is not None and not instance: 2165 | _parent._mock_children[_name] = mock 2166 | 2167 | if is_type and not instance and 'return_value' not in kwargs: 2168 | mock.return_value = create_autospec(spec, spec_set, instance=True, 2169 | _name='()', _parent=mock) 2170 | 2171 | for entry in dir(spec): 2172 | if _is_magic(entry): 2173 | # MagicMock already does the useful magic methods for us 2174 | 
continue 2175 | 2176 | # XXXX do we need a better way of getting attributes without 2177 | # triggering code execution (?) Probably not - we need the actual 2178 | # object to mock it so we would rather trigger a property than mock 2179 | # the property descriptor. Likewise we want to mock out dynamically 2180 | # provided attributes. 2181 | # XXXX what about attributes that raise exceptions other than 2182 | # AttributeError on being fetched? 2183 | # we could be resilient against it, or catch and propagate the 2184 | # exception when the attribute is fetched from the mock 2185 | try: 2186 | original = getattr(spec, entry) 2187 | except AttributeError: 2188 | continue 2189 | 2190 | kwargs = {'spec': original} 2191 | if spec_set: 2192 | kwargs = {'spec_set': original} 2193 | 2194 | if not isinstance(original, FunctionTypes): 2195 | new = _SpecState(original, spec_set, mock, entry, instance) 2196 | mock._mock_children[entry] = new 2197 | else: 2198 | parent = mock 2199 | if isinstance(spec, FunctionTypes): 2200 | parent = mock.mock 2201 | 2202 | skipfirst = _must_skip(spec, entry, is_type) 2203 | kwargs['_eat_self'] = skipfirst 2204 | new = MagicMock(parent=parent, name=entry, _new_name=entry, 2205 | _new_parent=parent, 2206 | **kwargs) 2207 | mock._mock_children[entry] = new 2208 | _check_signature(original, new, skipfirst=skipfirst) 2209 | 2210 | # so functions created with _set_signature become instance attributes, 2211 | # *plus* their underlying mock exists in _mock_children of the parent 2212 | # mock. Adding to _mock_children may be unnecessary where we are also 2213 | # setting as an instance attribute? 2214 | if isinstance(new, FunctionTypes): 2215 | setattr(mock, entry, new) 2216 | 2217 | return mock 2218 | 2219 | 2220 | def _must_skip(spec, entry, is_type): 2221 | """ 2222 | Return whether we should skip the first argument on spec's `entry` 2223 | attribute. 
2224 | """ 2225 | if not isinstance(spec, type): 2226 | if entry in getattr(spec, '__dict__', {}): 2227 | # instance attribute - shouldn't skip 2228 | return False 2229 | spec = spec.__class__ 2230 | 2231 | for klass in spec.__mro__: 2232 | result = klass.__dict__.get(entry, DEFAULT) 2233 | if result is DEFAULT: 2234 | continue 2235 | if isinstance(result, (staticmethod, classmethod)): 2236 | return False 2237 | elif isinstance(getattr(result, '__get__', None), MethodWrapperTypes): 2238 | # Normal method => skip if looked up on type 2239 | # (if looked up on instance, self is already skipped) 2240 | return is_type 2241 | else: 2242 | return False 2243 | 2244 | # shouldn't get here unless function is a dynamically provided attribute 2245 | # XXXX untested behaviour 2246 | return is_type 2247 | 2248 | 2249 | def _get_class(obj): 2250 | try: 2251 | return obj.__class__ 2252 | except AttributeError: 2253 | # it is possible for objects to have no __class__ 2254 | return type(obj) 2255 | 2256 | 2257 | class _SpecState(object): 2258 | 2259 | def __init__(self, spec, spec_set=False, parent=None, 2260 | name=None, ids=None, instance=False): 2261 | self.spec = spec 2262 | self.ids = ids 2263 | self.spec_set = spec_set 2264 | self.parent = parent 2265 | self.instance = instance 2266 | self.name = name 2267 | 2268 | 2269 | FunctionTypes = ( 2270 | # python function 2271 | type(create_autospec), 2272 | # instance method 2273 | type(ANY.__eq__), 2274 | ) 2275 | 2276 | MethodWrapperTypes = ( 2277 | type(ANY.__eq__.__get__), 2278 | ) 2279 | 2280 | 2281 | file_spec = None 2282 | 2283 | def _iterate_read_data(read_data): 2284 | # Helper for mock_open: 2285 | # Retrieve lines from read_data via a generator so that separate calls to 2286 | # readline, read, and readlines are properly interleaved 2287 | sep = b'\n' if isinstance(read_data, bytes) else '\n' 2288 | data_as_list = [l + sep for l in read_data.split(sep)] 2289 | 2290 | if data_as_list[-1] == sep: 2291 | # If the last line 
ended in a newline, the list comprehension will have an 2292 | # extra entry that's just a newline. Remove this. 2293 | data_as_list = data_as_list[:-1] 2294 | else: 2295 | # If there wasn't an extra newline by itself, then the file being 2296 | # emulated doesn't have a newline to end the last line remove the 2297 | # newline that our naive format() added 2298 | data_as_list[-1] = data_as_list[-1][:-1] 2299 | 2300 | for line in data_as_list: 2301 | yield line 2302 | 2303 | 2304 | def mock_open(mock=None, read_data=''): 2305 | """ 2306 | A helper function to create a mock to replace the use of `open`. It works 2307 | for `open` called directly or used as a context manager. 2308 | 2309 | The `mock` argument is the mock object to configure. If `None` (the 2310 | default) then a `MagicMock` will be created for you, with the API limited 2311 | to methods or attributes available on standard file handles. 2312 | 2313 | `read_data` is a string for the `read` methoddline`, and `readlines` of the 2314 | file handle to return. This is an empty string by default. 
2315 | """ 2316 | def _readlines_side_effect(*args, **kwargs): 2317 | if handle.readlines.return_value is not None: 2318 | return handle.readlines.return_value 2319 | return list(_state[0]) 2320 | 2321 | def _read_side_effect(*args, **kwargs): 2322 | if handle.read.return_value is not None: 2323 | return handle.read.return_value 2324 | return type(read_data)().join(_state[0]) 2325 | 2326 | def _readline_side_effect(): 2327 | if handle.readline.return_value is not None: 2328 | while True: 2329 | yield handle.readline.return_value 2330 | for line in _state[0]: 2331 | yield line 2332 | while True: 2333 | yield type(read_data)() 2334 | 2335 | 2336 | global file_spec 2337 | if file_spec is None: 2338 | import _io 2339 | file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO)))) 2340 | 2341 | if mock is None: 2342 | mock = MagicMock(name='open', spec=open) 2343 | 2344 | handle = MagicMock(spec=file_spec) 2345 | handle.__enter__.return_value = handle 2346 | 2347 | _state = [_iterate_read_data(read_data), None] 2348 | 2349 | handle.write.return_value = None 2350 | handle.read.return_value = None 2351 | handle.readline.return_value = None 2352 | handle.readlines.return_value = None 2353 | 2354 | handle.read.side_effect = _read_side_effect 2355 | _state[1] = _readline_side_effect() 2356 | handle.readline.side_effect = _state[1] 2357 | handle.readlines.side_effect = _readlines_side_effect 2358 | 2359 | def reset_data(*args, **kwargs): 2360 | _state[0] = _iterate_read_data(read_data) 2361 | if handle.readline.side_effect == _state[1]: 2362 | # Only reset the side effect if the user hasn't overridden it. 2363 | _state[1] = _readline_side_effect() 2364 | handle.readline.side_effect = _state[1] 2365 | return DEFAULT 2366 | 2367 | mock.side_effect = reset_data 2368 | mock.return_value = handle 2369 | return mock 2370 | 2371 | 2372 | class PropertyMock(Mock): 2373 | """ 2374 | A mock intended to be used as a property, or other descriptor, on a class. 
2375 | `PropertyMock` provides `__get__` and `__set__` methods so you can specify 2376 | a return value when it is fetched. 2377 | 2378 | Fetching a `PropertyMock` instance from an object calls the mock, with 2379 | no args. Setting it calls the mock with the value being set. 2380 | """ 2381 | def _get_child_mock(self, **kwargs): 2382 | return MagicMock(**kwargs) 2383 | 2384 | def __get__(self, obj, obj_type): 2385 | return self() 2386 | def __set__(self, obj, val): 2387 | self(val) 2388 | -------------------------------------------------------------------------------- /fastunit/result.py: -------------------------------------------------------------------------------- 1 | """Test result object""" 2 | 3 | import io 4 | import sys 5 | import traceback 6 | 7 | from . import util 8 | from functools import wraps 9 | 10 | __unittest = True 11 | 12 | def failfast(method): 13 | @wraps(method) 14 | def inner(self, *args, **kw): 15 | if getattr(self, 'failfast', False): 16 | self.stop() 17 | return method(self, *args, **kw) 18 | return inner 19 | 20 | STDOUT_LINE = '\nStdout:\n%s' 21 | STDERR_LINE = '\nStderr:\n%s' 22 | 23 | 24 | class TestResult(object): 25 | """Holder for test result information. 26 | 27 | Test results are automatically managed by the TestCase and TestSuite 28 | classes, and do not need to be explicitly manipulated by writers of tests. 29 | 30 | Each instance holds the total number of tests run, and collections of 31 | failures and errors that occurred among those test runs. The collections 32 | contain tuples of (testcase, exceptioninfo), where exceptioninfo is the 33 | formatted traceback of the error that occurred. 
34 | """ 35 | _previousTestClass = None 36 | _testRunEntered = False 37 | _moduleSetUpFailed = False 38 | def __init__(self, stream=None, descriptions=None, verbosity=None): 39 | self.failfast = False 40 | self.failures = [] 41 | self.errors = [] 42 | self.testsRun = 0 43 | self.skipped = [] 44 | self.expectedFailures = [] 45 | self.unexpectedSuccesses = [] 46 | self.shouldStop = False 47 | self.buffer = False 48 | self.tb_locals = False 49 | self._stdout_buffer = None 50 | self._stderr_buffer = None 51 | self._original_stdout = sys.stdout 52 | self._original_stderr = sys.stderr 53 | self._mirrorOutput = False 54 | 55 | def printErrors(self): 56 | "Called by TestRunner after test run" 57 | 58 | def startTest(self, test): 59 | "Called when the given test is about to be run" 60 | self.testsRun += 1 61 | self._mirrorOutput = False 62 | self._setupStdout() 63 | 64 | def _setupStdout(self): 65 | if self.buffer: 66 | if self._stderr_buffer is None: 67 | self._stderr_buffer = io.StringIO() 68 | self._stdout_buffer = io.StringIO() 69 | sys.stdout = self._stdout_buffer 70 | sys.stderr = self._stderr_buffer 71 | 72 | def startTestRun(self): 73 | """Called once before any tests are executed. 74 | 75 | See startTest for a method called before each test. 
76 | """ 77 | 78 | def stopTest(self, test): 79 | """Called when the given test has been run""" 80 | self._restoreStdout() 81 | self._mirrorOutput = False 82 | 83 | def _restoreStdout(self): 84 | if self.buffer: 85 | if self._mirrorOutput: 86 | output = sys.stdout.getvalue() 87 | error = sys.stderr.getvalue() 88 | if output: 89 | if not output.endswith('\n'): 90 | output += '\n' 91 | self._original_stdout.write(STDOUT_LINE % output) 92 | if error: 93 | if not error.endswith('\n'): 94 | error += '\n' 95 | self._original_stderr.write(STDERR_LINE % error) 96 | 97 | sys.stdout = self._original_stdout 98 | sys.stderr = self._original_stderr 99 | self._stdout_buffer.seek(0) 100 | self._stdout_buffer.truncate() 101 | self._stderr_buffer.seek(0) 102 | self._stderr_buffer.truncate() 103 | 104 | def stopTestRun(self): 105 | """Called once after all tests are executed. 106 | 107 | See stopTest for a method called after each test. 108 | """ 109 | 110 | @failfast 111 | def addError(self, test, err): 112 | """Called when an error has occurred. 'err' is a tuple of values as 113 | returned by sys.exc_info(). 114 | """ 115 | self.errors.append((test, self._exc_info_to_string(err, test))) 116 | self._mirrorOutput = True 117 | 118 | @failfast 119 | def addFailure(self, test, err): 120 | """Called when an error has occurred. 'err' is a tuple of values as 121 | returned by sys.exc_info().""" 122 | self.failures.append((test, self._exc_info_to_string(err, test))) 123 | self._mirrorOutput = True 124 | 125 | def addSubTest(self, test, subtest, err): 126 | """Called at the end of a subtest. 127 | 'err' is None if the subtest ended successfully, otherwise it's a 128 | tuple of values as returned by sys.exc_info(). 129 | """ 130 | # By default, we don't do anything with successful subtests, but 131 | # more sophisticated test results might want to record them. 
132 | if err is not None: 133 | if getattr(self, 'failfast', False): 134 | self.stop() 135 | if issubclass(err[0], test.failureException): 136 | errors = self.failures 137 | else: 138 | errors = self.errors 139 | errors.append((subtest, self._exc_info_to_string(err, test))) 140 | self._mirrorOutput = True 141 | 142 | def addSuccess(self, test): 143 | "Called when a test has completed successfully" 144 | pass 145 | 146 | def addSkip(self, test, reason): 147 | """Called when a test is skipped.""" 148 | self.skipped.append((test, reason)) 149 | 150 | def addExpectedFailure(self, test, err): 151 | """Called when an expected failure/error occurred.""" 152 | self.expectedFailures.append( 153 | (test, self._exc_info_to_string(err, test))) 154 | 155 | @failfast 156 | def addUnexpectedSuccess(self, test): 157 | """Called when a test was expected to fail, but succeed.""" 158 | self.unexpectedSuccesses.append(test) 159 | 160 | def wasSuccessful(self): 161 | """Tells whether or not this result was a success.""" 162 | # The hasattr check is for test_result's OldResult test. That 163 | # way this method works on objects that lack the attribute. 164 | # (where would such result intances come from? old stored pickles?) 
165 | return ((len(self.failures) == len(self.errors) == 0) and 166 | (not hasattr(self, 'unexpectedSuccesses') or 167 | len(self.unexpectedSuccesses) == 0)) 168 | 169 | def stop(self): 170 | """Indicates that the tests should be aborted.""" 171 | self.shouldStop = True 172 | 173 | def _exc_info_to_string(self, err, test): 174 | """Converts a sys.exc_info()-style tuple of values into a string.""" 175 | exctype, value, tb = err 176 | # Skip test runner traceback levels 177 | while tb and self._is_relevant_tb_level(tb): 178 | tb = tb.tb_next 179 | 180 | if exctype is test.failureException: 181 | # Skip assert*() traceback levels 182 | length = self._count_relevant_tb_levels(tb) 183 | else: 184 | length = None 185 | tb_e = traceback.TracebackException( 186 | exctype, value, tb, limit=length, capture_locals=self.tb_locals) 187 | msgLines = list(tb_e.format()) 188 | 189 | if self.buffer: 190 | output = sys.stdout.getvalue() 191 | error = sys.stderr.getvalue() 192 | if output: 193 | if not output.endswith('\n'): 194 | output += '\n' 195 | msgLines.append(STDOUT_LINE % output) 196 | if error: 197 | if not error.endswith('\n'): 198 | error += '\n' 199 | msgLines.append(STDERR_LINE % error) 200 | return ''.join(msgLines) 201 | 202 | 203 | def _is_relevant_tb_level(self, tb): 204 | return '__unittest' in tb.tb_frame.f_globals 205 | 206 | def _count_relevant_tb_levels(self, tb): 207 | length = 0 208 | while tb and not self._is_relevant_tb_level(tb): 209 | length += 1 210 | tb = tb.tb_next 211 | return length 212 | 213 | def __repr__(self): 214 | return ("<%s run=%i errors=%i failures=%i>" % 215 | (util.strclass(self.__class__), self.testsRun, len(self.errors), 216 | len(self.failures))) 217 | -------------------------------------------------------------------------------- /fastunit/runner.py: -------------------------------------------------------------------------------- 1 | """Running tests""" 2 | 3 | import sys 4 | import time 5 | import warnings 6 | 7 | from . 
import result 8 | from .signals import registerResult 9 | 10 | __unittest = True 11 | 12 | 13 | class _WritelnDecorator(object): 14 | """Used to decorate file-like objects with a handy 'writeln' method""" 15 | def __init__(self,stream): 16 | self.stream = stream 17 | 18 | def __getattr__(self, attr): 19 | if attr in ('stream', '__getstate__'): 20 | raise AttributeError(attr) 21 | return getattr(self.stream,attr) 22 | 23 | def writeln(self, arg=None): 24 | if arg: 25 | self.write(arg) 26 | self.write('\n') # text-mode streams translate to \r\n if needed 27 | 28 | 29 | class TextTestResult(result.TestResult): 30 | """A test result class that can print formatted text results to a stream. 31 | 32 | Used by TextTestRunner. 33 | """ 34 | separator1 = '=' * 70 35 | separator2 = '-' * 70 36 | 37 | def __init__(self, stream, descriptions, verbosity): 38 | super(TextTestResult, self).__init__(stream, descriptions, verbosity) 39 | self.stream = stream 40 | self.showAll = verbosity > 1 41 | self.dots = verbosity == 1 42 | self.descriptions = descriptions 43 | 44 | def getDescription(self, test): 45 | doc_first_line = test.shortDescription() 46 | if self.descriptions and doc_first_line: 47 | return '\n'.join((str(test), doc_first_line)) 48 | else: 49 | return str(test) 50 | 51 | def startTest(self, test): 52 | super(TextTestResult, self).startTest(test) 53 | if self.showAll: 54 | self.stream.write(self.getDescription(test)) 55 | self.stream.write(" ... 
") 56 | self.stream.flush() 57 | 58 | def addSuccess(self, test): 59 | super(TextTestResult, self).addSuccess(test) 60 | if self.showAll: 61 | self.stream.writeln("ok") 62 | elif self.dots: 63 | self.stream.write('.') 64 | self.stream.flush() 65 | 66 | def addError(self, test, err): 67 | super(TextTestResult, self).addError(test, err) 68 | if self.showAll: 69 | self.stream.writeln("ERROR") 70 | elif self.dots: 71 | self.stream.write('E') 72 | self.stream.flush() 73 | 74 | def addFailure(self, test, err): 75 | super(TextTestResult, self).addFailure(test, err) 76 | if self.showAll: 77 | self.stream.writeln("FAIL") 78 | elif self.dots: 79 | self.stream.write('F') 80 | self.stream.flush() 81 | 82 | def addSkip(self, test, reason): 83 | super(TextTestResult, self).addSkip(test, reason) 84 | if self.showAll: 85 | self.stream.writeln("skipped {0!r}".format(reason)) 86 | elif self.dots: 87 | self.stream.write("s") 88 | self.stream.flush() 89 | 90 | def addExpectedFailure(self, test, err): 91 | super(TextTestResult, self).addExpectedFailure(test, err) 92 | if self.showAll: 93 | self.stream.writeln("expected failure") 94 | elif self.dots: 95 | self.stream.write("x") 96 | self.stream.flush() 97 | 98 | def addUnexpectedSuccess(self, test): 99 | super(TextTestResult, self).addUnexpectedSuccess(test) 100 | if self.showAll: 101 | self.stream.writeln("unexpected success") 102 | elif self.dots: 103 | self.stream.write("u") 104 | self.stream.flush() 105 | 106 | def printErrors(self): 107 | if self.dots or self.showAll: 108 | self.stream.writeln() 109 | self.printErrorList('ERROR', self.errors) 110 | self.printErrorList('FAIL', self.failures) 111 | 112 | def printErrorList(self, flavour, errors): 113 | for test, err in errors: 114 | self.stream.writeln(self.separator1) 115 | self.stream.writeln("%s: %s" % (flavour,self.getDescription(test))) 116 | self.stream.writeln(self.separator2) 117 | self.stream.writeln("%s" % err) 118 | 119 | 120 | class TextTestRunner(object): 121 | """A test 
runner class that displays results in textual form. 122 | 123 | It prints out the names of tests as they are run, errors as they 124 | occur, and a summary of the results at the end of the test run. 125 | """ 126 | resultclass = TextTestResult 127 | 128 | def __init__(self, stream=None, descriptions=True, verbosity=1, 129 | failfast=False, buffer=False, resultclass=None, warnings=None, 130 | *, tb_locals=False): 131 | """Construct a TextTestRunner. 132 | 133 | Subclasses should accept **kwargs to ensure compatibility as the 134 | interface changes. 135 | """ 136 | if stream is None: 137 | stream = sys.stderr 138 | self.stream = _WritelnDecorator(stream) 139 | self.descriptions = descriptions 140 | self.verbosity = verbosity 141 | self.failfast = failfast 142 | self.buffer = buffer 143 | self.tb_locals = tb_locals 144 | self.warnings = warnings 145 | if resultclass is not None: 146 | self.resultclass = resultclass 147 | 148 | def _makeResult(self): 149 | return self.resultclass(self.stream, self.descriptions, self.verbosity) 150 | 151 | def run(self, test): 152 | "Run the given test case or test suite." 153 | result = self._makeResult() 154 | registerResult(result) 155 | result.failfast = self.failfast 156 | result.buffer = self.buffer 157 | result.tb_locals = self.tb_locals 158 | with warnings.catch_warnings(): 159 | if self.warnings: 160 | # if self.warnings is set, use it to filter all the warnings 161 | warnings.simplefilter(self.warnings) 162 | # if the filter is 'default' or 'always', special-case the 163 | # warnings from the deprecated unittest methods to show them 164 | # no more than once per module, because they can be fairly 165 | # noisy. The -Wd and -Wa flags can be used to bypass this 166 | # only when self.warnings is None. 
167 | if self.warnings in ['default', 'always']: 168 | warnings.filterwarnings('module', 169 | category=DeprecationWarning, 170 | message='Please use assert\w+ instead.') 171 | startTime = time.time() 172 | startTestRun = getattr(result, 'startTestRun', None) 173 | if startTestRun is not None: 174 | startTestRun() 175 | try: 176 | test(result) 177 | finally: 178 | stopTestRun = getattr(result, 'stopTestRun', None) 179 | if stopTestRun is not None: 180 | stopTestRun() 181 | stopTime = time.time() 182 | timeTaken = stopTime - startTime 183 | result.printErrors() 184 | if hasattr(result, 'separator2'): 185 | self.stream.writeln(result.separator2) 186 | run = result.testsRun 187 | self.stream.writeln("Ran %d test%s in %.3fs" % 188 | (run, run != 1 and "s" or "", timeTaken)) 189 | self.stream.writeln() 190 | 191 | expectedFails = unexpectedSuccesses = skipped = 0 192 | try: 193 | results = map(len, (result.expectedFailures, 194 | result.unexpectedSuccesses, 195 | result.skipped)) 196 | except AttributeError: 197 | pass 198 | else: 199 | expectedFails, unexpectedSuccesses, skipped = results 200 | 201 | infos = [] 202 | if not result.wasSuccessful(): 203 | self.stream.write("FAILED") 204 | failed, errored = len(result.failures), len(result.errors) 205 | if failed: 206 | infos.append("failures=%d" % failed) 207 | if errored: 208 | infos.append("errors=%d" % errored) 209 | else: 210 | self.stream.write("OK") 211 | if skipped: 212 | infos.append("skipped=%d" % skipped) 213 | if expectedFails: 214 | infos.append("expected failures=%d" % expectedFails) 215 | if unexpectedSuccesses: 216 | infos.append("unexpected successes=%d" % unexpectedSuccesses) 217 | if infos: 218 | self.stream.writeln(" (%s)" % (", ".join(infos),)) 219 | else: 220 | self.stream.write("\n") 221 | return result 222 | -------------------------------------------------------------------------------- /fastunit/signals.py: -------------------------------------------------------------------------------- 1 | 
import signal
import weakref

from functools import wraps

__unittest = True


class _InterruptHandler(object):
    """SIGINT handler that asks registered results to stop gracefully.

    The first interrupt calls ``stop()`` on every result registered through
    ``registerResult`` so the run can wind down.  A repeated interrupt, or an
    interrupt received while this object is no longer the installed SIGINT
    handler, is delegated to the handler that was active before.
    """

    def __init__(self, default_handler):
        self.called = False
        # Keep the raw previous handler so removeHandler() can restore it.
        self.original_handler = default_handler
        if isinstance(default_handler, int):
            if default_handler == signal.SIG_DFL:
                # Pretend it's signal.default_int_handler instead.
                default_handler = signal.default_int_handler
            elif default_handler == signal.SIG_IGN:
                # Not quite the same thing as SIG_IGN, but the closest we
                # can make it: do nothing.
                def default_handler(unused_signum, unused_frame):
                    pass
            else:
                raise TypeError("expected SIGINT signal handler to be "
                                "signal.SIG_IGN, signal.SIG_DFL, or a "
                                "callable object")
        self.default_handler = default_handler

    def __call__(self, signum, frame):
        installed_handler = signal.getsignal(signal.SIGINT)
        if installed_handler is not self:
            # if we aren't the installed handler, then delegate immediately
            # to the default handler
            self.default_handler(signum, frame)

        if self.called:
            # Second interrupt: fall back to the previous behaviour.
            self.default_handler(signum, frame)
        self.called = True
        for result in _results.keys():
            result.stop()


# Results are held weakly so registering one never keeps it alive.
_results = weakref.WeakKeyDictionary()


def registerResult(result):
    """Track *result* so an interrupt can call its ``stop()`` method."""
    _results[result] = 1


def removeResult(result):
    """Stop tracking *result*; return True if it had been registered."""
    return bool(_results.pop(result, None))


_interrupt_handler = None


def installHandler():
    """Install the result-stopping SIGINT handler.

    The handler object is created on the first call only; later calls leave
    the current SIGINT disposition untouched.
    """
    global _interrupt_handler
    if _interrupt_handler is None:
        default_handler = signal.getsignal(signal.SIGINT)
        _interrupt_handler = _InterruptHandler(default_handler)
        signal.signal(signal.SIGINT, _interrupt_handler)


def removeHandler(method=None):
    """Restore the SIGINT handler that was active before installHandler().

    When *method* is given, act as a decorator instead.  The returned wrapper
    removes the handler for the duration of each call and afterwards
    re-installs whatever SIGINT handler was in place when the call started.
    """
    if method is not None:
        @wraps(method)
        def inner(*args, **kwargs):
            initial = signal.getsignal(signal.SIGINT)
            removeHandler()
            try:
                return method(*args, **kwargs)
            finally:
                signal.signal(signal.SIGINT, initial)
        return inner

    global _interrupt_handler
    if _interrupt_handler is not None:
        signal.signal(signal.SIGINT, _interrupt_handler.original_handler)


# --------------------------------------------------------------------------
# /fastunit/suite.py
# --------------------------------------------------------------------------
"""TestSuite"""

import sys
import asyncio

from . import case
from . import util

__unittest = True


def _call_if_exists(parent, attr):
    """Call ``parent.<attr>()`` if that attribute exists; otherwise do nothing."""
    func = getattr(parent, attr, lambda: None)
    func()


class BaseTestSuite(object):
    """A simple test suite that doesn't provide class or module shared fixtures.
    """
    # When true, each test is replaced by None once it has run, so the suite
    # stops holding a reference to it (see _removeTestAtIndex).
    _cleanup = True

    def __init__(self, tests=()):
        self._tests = []
        # Number of test cases already released by _removeTestAtIndex, kept
        # so countTestCases() stays stable after cleanup.
        self._removed_tests = 0
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return list(self) == list(other)

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the number of test cases, including ones already released."""
        cases = self._removed_tests
        for test in self:
            if test:
                cases += test.countTestCases()
        return cases

    def addTest(self, test):
        """Append one callable test (a test instance or a nested suite).

        :raises TypeError: if *test* is not callable, or is a TestCase or
            TestSuite class rather than an instance.
        """
        # sanity checks
        if not callable(test):
            raise TypeError("{} is not callable".format(repr(test)))
        if isinstance(test, type) and issubclass(test,
                                                 (case.TestCase, TestSuite)):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests*.

        :raises TypeError: if *tests* is a string.
        """
        if isinstance(tests, str):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Run the tests sequentially until *result* asks to stop."""
        for index, test in enumerate(self):
            if result.shouldStop:
                break
            test(result)
            if self._cleanup:
                self._removeTestAtIndex(index)
        return result

    def _removeTestAtIndex(self, index):
        """Stop holding a reference to the TestCase at index."""
        try:
            test = self._tests[index]
        except TypeError:
            # support for suite implementations that have overridden self._tests
            pass
        else:
            # Some unittest tests add non TestCase/TestSuite objects to
            # the suite.
            if hasattr(test, 'countTestCases'):
                self._removed_tests += test.countTestCases()
            self._tests[index] = None

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self:
            test.debug()


class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.
    """

    def run(self, result, debug=False):
        """Run all tests of this suite concurrently and return *result*.

        Each test is wrapped in a ``startRunCase`` coroutine.  All of them
        are driven to completion on a fresh event loop, and the test bodies
        themselves execute in that loop's default executor.  The outermost
        suite (the first one to see *result*) also runs the final
        tearDownClass / tearDownModule fixtures.

        :param result: the TestResult (or _DebugResult) to report into.
        :param debug: when true, tests are run through ``test.debug()`` so
            their exceptions propagate instead of being recorded.
        :returns: *result*.
        :raises BaseException: the first exception raised by one of the
            per-test coroutines.  It is re-raised after the loop has been
            closed and the top-level fixtures have been torn down.
        """
        topLevel = False
        if getattr(result, '_testRunEntered', False) is False:
            result._testRunEntered = topLevel = True

        # Each run drives its own loop: nested suites are executed inside
        # executor threads, which have no event loop of their own.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        outcomes = []
        try:
            # Create the tasks explicitly and in suite order.  Passing bare
            # coroutines to asyncio.wait() is a TypeError on Python 3.11+.
            # On older versions it scheduled them through a set, so the
            # fixture bookkeeping in startRunCase ran in arbitrary order.
            tasks = [loop.create_task(
                         self.startRunCase(index, test, result, debug))
                     for index, test in enumerate(self)]
            if tasks:
                # return_exceptions=True lets every case finish; failures
                # are re-raised below once everything is cleaned up.
                outcomes = loop.run_until_complete(
                    asyncio.gather(*tasks, return_exceptions=True))
        finally:
            loop.close()

        if topLevel:
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            result._testRunEntered = False

        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                raise outcome
        return result

    async def startRunCase(self, index, test, result, debug=False):
        """Run the test at position *index* of this suite as a coroutine.

        Fixture handling (tearDownClass of the previous class, module
        fixtures, setUpClass) runs on the event loop before the first
        ``await``.  The test itself runs in the loop's default executor, so
        other cases can proceed meanwhile.

        :param debug: run ``test.debug()`` instead of ``test(result)``.
        :returns: False if the run was already stopped, True if the case
            was skipped because its class or module setup failed, and None
            otherwise.
        """
        loop = asyncio.get_event_loop()
        if result.shouldStop:
            return False

        if _isnotsuite(test):
            self._tearDownPreviousClass(test, result)
            self._handleModuleFixture(test, result)
            self._handleClassSetUp(test, result)
            result._previousTestClass = test.__class__

            if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                return True

        if debug:
            # Debug mode must not route through the result object: a
            # _DebugResult has none of the TestResult reporting methods.
            await loop.run_in_executor(None, test.debug)
        else:
            await loop.run_in_executor(None, test, result)

        if self._cleanup:
            self._removeTestAtIndex(index)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        debug = _DebugResult()
        self.run(debug, True)

    ################################

    def _handleClassSetUp(self, test, result):
        """Call setUpClass when *test* belongs to a new class.

        A failure marks the class with ``_classSetupFailed``.  It is reported
        as an error (or a skip for SkipTest), except under a _DebugResult,
        where the exception is re-raised.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if result._moduleSetUpFailed:
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _get_previous_module(self, result):
        """Return the module name of the last test class seen, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """On a module change, tear down the old module and set up the new one."""
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Report a fixture exception as a skip (SkipTest) or an error.

        Relies on being called from inside the ``except`` clause, because
        the error branch reads the traceback from ``sys.exc_info()``.
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call tearDownModule of the previous test's module, if any."""
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if result._moduleSetUpFailed:
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _tearDownPreviousClass(self, test, result):
        """Call tearDownClass of the previous class when *test*'s class differs.

        Called with ``test=None`` at the end of a top-level run to tear down
        the last class.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')


class _ErrorHolder(object):
    """
    Placeholder for a TestCase inside a result. As far as a TestResult
    is concerned, this looks exactly like a unit test. Used to insert
    arbitrary errors into a test suite run.
    """
    # Inspired by the ErrorHolder from Twisted:
    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py

    # attribute used by TestResult._exc_info_to_string
    failureException = None

    def __init__(self, description):
        self.description = description

    def id(self):
        return self.description

    def shortDescription(self):
        return None

    def __repr__(self):
        # The format string was empty, so "%" raised TypeError ("not all
        # arguments converted") whenever an _ErrorHolder was repr()'d.
        return "<ErrorHolder description=%r>" % (self.description,)

    def __str__(self):
        return self.id()

    def run(self, result):
        # could call result.addError(...) - but this test-like object
        # shouldn't be run anyway
        pass

    def __call__(self, result):
        return self.run(result)

    def countTestCases(self):
        return 0


def _isnotsuite(test):
    "A crude way to tell apart testcases and suites with duck-typing"
    try:
        iter(test)
    except TypeError:
        return True
    return False


class _DebugResult(object):
    "Used by the TestSuite to hold previous class when running in debug."
    _previousTestClass = None
    _moduleSetUpFailed = False
    shouldStop = False


# --------------------------------------------------------------------------
# /fastunit/util.py
# --------------------------------------------------------------------------
"""Various utility functions."""

from collections import namedtuple, OrderedDict
from os.path import commonprefix

__unittest = True

# Length budget used when shortening reprs for assertion messages.
_MAX_LENGTH = 80
# Room reserved for a "[N chars]" placeholder.
_PLACEHOLDER_LEN = 12
_MIN_BEGIN_LEN = 5
_MIN_END_LEN = 5
_MIN_COMMON_LEN = 5
_MIN_DIFF_LEN = _MAX_LENGTH - \
    (_MIN_BEGIN_LEN + _PLACEHOLDER_LEN + _MIN_COMMON_LEN +
     _PLACEHOLDER_LEN + _MIN_END_LEN)
assert _MIN_DIFF_LEN >= 0


def _shorten(s, prefixlen, suffixlen):
    """Abbreviate the middle of *s* as ``[N chars]``.

    The first *prefixlen* and last *suffixlen* characters are kept.  The
    string is returned unchanged unless more than _PLACEHOLDER_LEN
    characters would be dropped.
    """
    skip = len(s) - prefixlen - suffixlen
    if skip > _PLACEHOLDER_LEN:
        s = '%s[%d chars]%s' % (s[:prefixlen], skip, s[len(s) - suffixlen:])
    return s


def _common_shorten_repr(*args):
    """Return ``safe_repr`` of each argument, shortened for comparison output.

    If the longest repr exceeds _MAX_LENGTH, the shared prefix is
    abbreviated first.  Only if that is not enough are the differing tails
    abbreviated as well.
    """
    args = tuple(map(safe_repr, args))
    maxlen = max(map(len, args))
    if maxlen <= _MAX_LENGTH:
        return args

    prefix = commonprefix(args)
    prefixlen = len(prefix)

    common_len = _MAX_LENGTH - \
        (maxlen - prefixlen + _MIN_BEGIN_LEN + _PLACEHOLDER_LEN)
    if common_len > _MIN_COMMON_LEN:
        assert _MIN_BEGIN_LEN + _PLACEHOLDER_LEN + _MIN_COMMON_LEN + \
            (maxlen - prefixlen) < _MAX_LENGTH
        prefix = _shorten(prefix, _MIN_BEGIN_LEN, common_len)
        return tuple(prefix + s[prefixlen:] for s in args)

    prefix = _shorten(prefix, _MIN_BEGIN_LEN, _MIN_COMMON_LEN)
    return tuple(prefix + _shorten(s[prefixlen:], _MIN_DIFF_LEN, _MIN_END_LEN)
                 for s in args)


def safe_repr(obj, short=False):
    """Return ``repr(obj)``, falling back to ``object.__repr__`` if it raises.

    With *short* true, results of _MAX_LENGTH characters or more are cut to
    _MAX_LENGTH characters and suffixed with ``' [truncated]...'``.
    """
    try:
        result = repr(obj)
    except Exception:
        result = object.__repr__(obj)
    if not short or len(result) < _MAX_LENGTH:
        return result
    return result[:_MAX_LENGTH] + ' [truncated]...'
53 | 54 | def strclass(cls): 55 | return "%s.%s" % (cls.__module__, cls.__qualname__) 56 | 57 | def sorted_list_difference(expected, actual): 58 | """Finds elements in only one or the other of two, sorted input lists. 59 | 60 | Returns a two-element tuple of lists. The first list contains those 61 | elements in the "expected" list but not in the "actual" list, and the 62 | second contains those elements in the "actual" list but not in the 63 | "expected" list. Duplicate elements in either input list are ignored. 64 | """ 65 | i = j = 0 66 | missing = [] 67 | unexpected = [] 68 | while True: 69 | try: 70 | e = expected[i] 71 | a = actual[j] 72 | if e < a: 73 | missing.append(e) 74 | i += 1 75 | while expected[i] == e: 76 | i += 1 77 | elif e > a: 78 | unexpected.append(a) 79 | j += 1 80 | while actual[j] == a: 81 | j += 1 82 | else: 83 | i += 1 84 | try: 85 | while expected[i] == e: 86 | i += 1 87 | finally: 88 | j += 1 89 | while actual[j] == a: 90 | j += 1 91 | except IndexError: 92 | missing.extend(expected[i:]) 93 | unexpected.extend(actual[j:]) 94 | break 95 | return missing, unexpected 96 | 97 | 98 | def unorderable_list_difference(expected, actual): 99 | """Same behavior as sorted_list_difference but 100 | for lists of unorderable items (like dicts). 
101 | 102 | As it does a linear search per item (remove) it 103 | has O(n*n) performance.""" 104 | missing = [] 105 | while expected: 106 | item = expected.pop() 107 | try: 108 | actual.remove(item) 109 | except ValueError: 110 | missing.append(item) 111 | 112 | # anything left in actual is unexpected 113 | return missing, actual 114 | 115 | def three_way_cmp(x, y): 116 | """Return -1 if x < y, 0 if x == y and 1 if x > y""" 117 | return (x > y) - (x < y) 118 | 119 | _Mismatch = namedtuple('Mismatch', 'actual expected value') 120 | 121 | def _count_diff_all_purpose(actual, expected): 122 | 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' 123 | # elements need not be hashable 124 | s, t = list(actual), list(expected) 125 | m, n = len(s), len(t) 126 | NULL = object() 127 | result = [] 128 | for i, elem in enumerate(s): 129 | if elem is NULL: 130 | continue 131 | cnt_s = cnt_t = 0 132 | for j in range(i, m): 133 | if s[j] == elem: 134 | cnt_s += 1 135 | s[j] = NULL 136 | for j, other_elem in enumerate(t): 137 | if other_elem == elem: 138 | cnt_t += 1 139 | t[j] = NULL 140 | if cnt_s != cnt_t: 141 | diff = _Mismatch(cnt_s, cnt_t, elem) 142 | result.append(diff) 143 | 144 | for i, elem in enumerate(t): 145 | if elem is NULL: 146 | continue 147 | cnt_t = 0 148 | for j in range(i, n): 149 | if t[j] == elem: 150 | cnt_t += 1 151 | t[j] = NULL 152 | diff = _Mismatch(0, cnt_t, elem) 153 | result.append(diff) 154 | return result 155 | 156 | def _ordered_count(iterable): 157 | 'Return dict of element counts, in the order they were first seen' 158 | c = OrderedDict() 159 | for elem in iterable: 160 | c[elem] = c.get(elem, 0) + 1 161 | return c 162 | 163 | def _count_diff_hashable(actual, expected): 164 | 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' 165 | # elements must be hashable 166 | s, t = _ordered_count(actual), _ordered_count(expected) 167 | result = [] 168 | for elem, cnt_s in s.items(): 169 | cnt_t = t.get(elem, 0) 
170 | if cnt_s != cnt_t: 171 | diff = _Mismatch(cnt_s, cnt_t, elem) 172 | result.append(diff) 173 | for elem, cnt_t in t.items(): 174 | if elem not in s: 175 | diff = _Mismatch(0, cnt_t, elem) 176 | result.append(diff) 177 | return result 178 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | echo 'start remove build' 2 | rm -rf build/ 3 | echo 'end remove build' 4 | python setup.py install 5 | rm -rf build/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'fastunit', 5 | version = '0.0.1', 6 | keywords='async unittest corotine', 7 | description = 'a library for running test cases asynchronously.', 8 | license = 'MIT License', 9 | url = 'https://github.com/ityoung/fastunit', 10 | author = 'Shin Yeung', 11 | author_email = 'ityoung@foxmail.com', 12 | packages = find_packages(), 13 | include_package_data = True, 14 | platforms = 'any', 15 | install_requires = [], 16 | ) 17 | 18 | classifiers = [ 19 | 'Development Status :: 3 - Alpha', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Operating System :: POSIX', 23 | 'Operating System :: Microsoft :: Windows', 24 | 'Operating System :: MacOS :: MacOS X', 25 | 'Topic :: Software Development :: Testing', 26 | 'Topic :: Software Development :: Libraries', 27 | 'Topic :: Utilities', 28 | ] + [ 29 | ('Programming Language :: Python :: %s' % x) 30 | for x in '3.5 3.6'.split() 31 | ] 32 | --------------------------------------------------------------------------------