diff --git a/python-stdlib/unittest/unittest/__init__.py b/python-stdlib/unittest/unittest/__init__.py index 8014e2828..0f3ba9cc3 100644 --- a/python-stdlib/unittest/unittest/__init__.py +++ b/python-stdlib/unittest/unittest/__init__.py @@ -1,11 +1,6 @@ import io import sys -try: - import traceback -except ImportError: - traceback = None - class SkipTest(Exception): pass @@ -19,9 +14,12 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, tb): + if self.expected is None: + # Used by assertWarns, do nothing. + return self.exception = exc_value if exc_type is None: - assert False, "%r not raised" % self.expected + raise AssertionError("%r not raised" % self.expected) if issubclass(exc_type, self.expected): # store exception for later retrieval self.exception = exc_value @@ -42,8 +40,8 @@ def __init__(self, msg=None, params=None): def __enter__(self): pass - def __exit__(self, *exc_info): - if exc_info[0] is not None: + def __exit__(self, exc_type, exc_value, tb): + if exc_type is not None: # Exception raised global __test_result__, __current_test__ test_details = __current_test__ @@ -53,19 +51,11 @@ def __exit__(self, *exc_info): detail = ", ".join("%s=%s" % k_v for k_v in self.params.items()) test_details += (" (%s)" % detail,) - _handle_test_exception(test_details, __test_result__, exc_info, False) + _handle_test_exception(test_details, __test_result__, exc_value, False) # Suppress the exception as we've captured it above return True -class NullContext: - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_value, traceback): - pass - - class TestCase: def __init__(self): pass @@ -88,27 +78,31 @@ def skipTest(self, reason): raise SkipTest(reason) def fail(self, msg=""): - assert False, msg + raise AssertionError(msg) def assertEqual(self, x, y, msg=""): - if not msg: - msg = "%r vs (expected) %r" % (x, y) - assert x == y, msg + if not x == y: + if not msg: + msg = "%r vs (expected) %r" % (x, y) + raise AssertionError(msg) def assertNotEqual(self, x, y, msg=""): - if not msg: - msg = "%r not expected to be equal %r" % (x, y) - assert x != y, msg + if not x != y: + if not msg: + msg = "%r not expected to be equal %r" % (x, y) + raise AssertionError(msg) def assertLessEqual(self, x, y, msg=None): - if msg is None: - msg = "%r is expected to be <= %r" % (x, y) - assert x <= y, msg + if not x <= y: + if msg is None: + msg = "%r is expected to be <= %r" % (x, y) + raise AssertionError(msg) def assertGreaterEqual(self, x, y, msg=None): - if msg is None: - msg = "%r is expected to be >= %r" % (x, y) - assert x >= y, msg + if not x >= y: + if msg is None: + msg = "%r is expected to be >= %r" % (x, y) + raise AssertionError(msg) def assertAlmostEqual(self, x, y, places=None, msg="", delta=None): if x == y: @@ -129,7 +123,7 @@ def assertAlmostEqual(self, x, y, places=None, msg="", delta=None): if not msg: msg = "%r != %r within %r places" % (x, y, places) - assert False, msg + raise AssertionError(msg) def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None): if delta is not None and places is not None: @@ -148,45 +142,53 @@ def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None): if not msg: msg = "%r == %r within %r places" % (x, y, places) - assert False, msg + raise AssertionError(msg) def assertIs(self, x, y, msg=""): - if not msg: - msg = "%r is not %r" % (x, y) - assert x is y, msg + if not x is y: + if not msg: + msg = "%r is not %r" % (x, y) + raise AssertionError(msg) def assertIsNot(self, x, y, msg=""): - if not msg: - msg = "%r is %r" % (x, y) - assert x is not y, msg + if not x is not y: + if not msg: + msg = "%r is %r" % (x, y) + raise AssertionError(msg) def assertIsNone(self, x, msg=""): - if not msg: - msg = "%r is not None" % x - assert x is None, msg + if not x is None: + if not msg: + msg = "%r is not None" % x + raise AssertionError(msg) def assertIsNotNone(self, x, msg=""): - if not msg: - msg = "%r is None" % x - assert x is not None, msg + if not x is not None: + if not msg: + msg = "%r is None" % x + raise AssertionError(msg) def assertTrue(self, x, msg=""): - if not msg: - msg = "Expected %r to be True" % x - assert x, msg + if not x: + if not msg: + msg = "Expected %r to be True" % x + raise AssertionError(msg) def assertFalse(self, x, msg=""): - if not msg: - msg = "Expected %r to be False" % x - assert not x, msg + if x: + if not msg: + msg = "Expected %r to be False" % x + raise AssertionError(msg) def assertIn(self, x, y, msg=""): - if not msg: - msg = "Expected %r to be in %r" % (x, y) - assert x in y, msg + if not x in y: + if not msg: + msg = "Expected %r to be in %r" % (x, y) + raise AssertionError(msg) def assertIsInstance(self, x, y, msg=""): - assert isinstance(x, y), msg + if not isinstance(x, y): + raise AssertionError(msg) def assertRaises(self, exc, func=None, *args, **kwargs): if func is None: @@ -199,10 +201,10 @@ def assertRaises(self, exc, func=None, *args, **kwargs): return raise e - assert False, "%r not raised" % exc + raise AssertionError("%r not raised" % exc) def assertWarns(self, warn): - return NullContext() + return AssertRaisesContext(None) def skip(msg): @@ -235,7 +237,7 @@ def test_exp_fail(*args, **kwargs): except: pass else: - assert False, "unexpected success" + raise AssertionError("unexpected success") return test_exp_fail @@ -332,28 +334,19 @@ def __add__(self, other): return self -def _capture_exc(exc, exc_traceback): - buf = io.StringIO() - if hasattr(sys, "print_exception"): - sys.print_exception(exc, buf) - elif traceback is not None: - traceback.print_exception(None, exc, exc_traceback, file=buf) - return buf.getvalue() - - def _handle_test_exception( - current_test: tuple, test_result: TestResult, exc_info: tuple, verbose=True + current_test: tuple, test_result: TestResult, exc: Exception, verbose=True ): - exc = exc_info[1] - traceback = exc_info[2] - ex_str = _capture_exc(exc, traceback) if isinstance(exc, SkipTest): reason = exc.args[0] test_result.skippedNum += 1 test_result.skipped.append((current_test, reason)) print(" skipped:", reason) return - elif isinstance(exc, AssertionError): + buf = io.StringIO() + sys.print_exception(exc, buf) + ex_str = buf.getvalue() + if isinstance(exc, AssertionError): test_result.failuresNum += 1 test_result.failures.append((current_test, ex_str)) if verbose: @@ -402,9 +395,7 @@ def run_one(test_function): else: print(" ok") except Exception as ex: - _handle_test_exception( - current_test=(name, c), test_result=test_result, exc_info=(type(ex), ex, None) - ) + _handle_test_exception((name, c), test_result, ex) # Uncomment to investigate failure in detail # raise ex finally: