Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 65 additions & 74 deletions python-stdlib/unittest/unittest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import io
import sys

try:
import traceback
except ImportError:
traceback = None


class SkipTest(Exception):
pass
Expand All @@ -19,9 +14,12 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, tb):
if self.expected is None:
# Used by assertWarns, do nothing.
return
self.exception = exc_value
if exc_type is None:
assert False, "%r not raised" % self.expected
raise AssertionError("%r not raised" % self.expected)
if issubclass(exc_type, self.expected):
# store exception for later retrieval
self.exception = exc_value
Expand All @@ -42,8 +40,8 @@ def __init__(self, msg=None, params=None):
def __enter__(self):
pass

def __exit__(self, *exc_info):
if exc_info[0] is not None:
def __exit__(self, exc_type, exc_value, tb):
if exc_type is not None:
# Exception raised
global __test_result__, __current_test__
test_details = __current_test__
Expand All @@ -53,19 +51,11 @@ def __exit__(self, *exc_info):
detail = ", ".join("%s=%s" % k_v for k_v in self.params.items())
test_details += (" (%s)" % detail,)

_handle_test_exception(test_details, __test_result__, exc_info, False)
_handle_test_exception(test_details, __test_result__, exc_value, False)
# Suppress the exception as we've captured it above
return True


class NullContext:
def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback):
pass


class TestCase:
def __init__(self):
pass
Expand All @@ -88,27 +78,31 @@ def skipTest(self, reason):
raise SkipTest(reason)

def fail(self, msg=""):
assert False, msg
raise AssertionError(msg)

def assertEqual(self, x, y, msg=""):
if not msg:
msg = "%r vs (expected) %r" % (x, y)
assert x == y, msg
if not x == y:
if not msg:
msg = "%r vs (expected) %r" % (x, y)
raise AssertionError(msg)

def assertNotEqual(self, x, y, msg=""):
if not msg:
msg = "%r not expected to be equal %r" % (x, y)
assert x != y, msg
if not x != y:
if not msg:
msg = "%r not expected to be equal %r" % (x, y)
raise AssertionError(msg)

def assertLessEqual(self, x, y, msg=None):
if msg is None:
msg = "%r is expected to be <= %r" % (x, y)
assert x <= y, msg
if not x <= y:
if msg is None:
msg = "%r is expected to be <= %r" % (x, y)
raise AssertionError(msg)

def assertGreaterEqual(self, x, y, msg=None):
if msg is None:
msg = "%r is expected to be >= %r" % (x, y)
assert x >= y, msg
if not x >= y:
if msg is None:
msg = "%r is expected to be >= %r" % (x, y)
raise AssertionError(msg)

def assertAlmostEqual(self, x, y, places=None, msg="", delta=None):
if x == y:
Expand All @@ -129,7 +123,7 @@ def assertAlmostEqual(self, x, y, places=None, msg="", delta=None):
if not msg:
msg = "%r != %r within %r places" % (x, y, places)

assert False, msg
raise AssertionError(msg)

def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None):
if delta is not None and places is not None:
Expand All @@ -148,45 +142,53 @@ def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None):
if not msg:
msg = "%r == %r within %r places" % (x, y, places)

assert False, msg
raise AssertionError(msg)

def assertIs(self, x, y, msg=""):
if not msg:
msg = "%r is not %r" % (x, y)
assert x is y, msg
if not x is y:
if not msg:
msg = "%r is not %r" % (x, y)
raise AssertionError(msg)

def assertIsNot(self, x, y, msg=""):
if not msg:
msg = "%r is %r" % (x, y)
assert x is not y, msg
if not x is not y:
if not msg:
msg = "%r is %r" % (x, y)
raise AssertionError(msg)

def assertIsNone(self, x, msg=""):
if not msg:
msg = "%r is not None" % x
assert x is None, msg
if not x is None:
if not msg:
msg = "%r is not None" % x
raise AssertionError(msg)

def assertIsNotNone(self, x, msg=""):
if not msg:
msg = "%r is None" % x
assert x is not None, msg
if not x is not None:
if not msg:
msg = "%r is None" % x
raise AssertionError(msg)

def assertTrue(self, x, msg=""):
if not msg:
msg = "Expected %r to be True" % x
assert x, msg
if not x:
if not msg:
msg = "Expected %r to be True" % x
raise AssertionError(msg)

def assertFalse(self, x, msg=""):
if not msg:
msg = "Expected %r to be False" % x
assert not x, msg
if x:
if not msg:
msg = "Expected %r to be False" % x
raise AssertionError(msg)

def assertIn(self, x, y, msg=""):
if not msg:
msg = "Expected %r to be in %r" % (x, y)
assert x in y, msg
if not x in y:
if not msg:
msg = "Expected %r to be in %r" % (x, y)
raise AssertionError(msg)

def assertIsInstance(self, x, y, msg=""):
assert isinstance(x, y), msg
if not isinstance(x, y):
raise AssertionError(msg)

def assertRaises(self, exc, func=None, *args, **kwargs):
if func is None:
Expand All @@ -199,10 +201,10 @@ def assertRaises(self, exc, func=None, *args, **kwargs):
return
raise e

assert False, "%r not raised" % exc
raise AssertionError("%r not raised" % exc)

def assertWarns(self, warn):
return NullContext()
return AssertRaisesContext(None)


def skip(msg):
Expand Down Expand Up @@ -235,7 +237,7 @@ def test_exp_fail(*args, **kwargs):
except:
pass
else:
assert False, "unexpected success"
raise AssertionError("unexpected success")

return test_exp_fail

Expand Down Expand Up @@ -332,28 +334,19 @@ def __add__(self, other):
return self


def _capture_exc(exc, exc_traceback):
buf = io.StringIO()
if hasattr(sys, "print_exception"):
sys.print_exception(exc, buf)
elif traceback is not None:
traceback.print_exception(None, exc, exc_traceback, file=buf)
return buf.getvalue()


def _handle_test_exception(
current_test: tuple, test_result: TestResult, exc_info: tuple, verbose=True
current_test: tuple, test_result: TestResult, exc: Exception, verbose=True
):
exc = exc_info[1]
traceback = exc_info[2]
ex_str = _capture_exc(exc, traceback)
if isinstance(exc, SkipTest):
reason = exc.args[0]
test_result.skippedNum += 1
test_result.skipped.append((current_test, reason))
print(" skipped:", reason)
return
elif isinstance(exc, AssertionError):
buf = io.StringIO()
sys.print_exception(exc, buf)
ex_str = buf.getvalue()
if isinstance(exc, AssertionError):
test_result.failuresNum += 1
test_result.failures.append((current_test, ex_str))
if verbose:
Expand Down Expand Up @@ -402,9 +395,7 @@ def run_one(test_function):
else:
print(" ok")
except Exception as ex:
_handle_test_exception(
current_test=(name, c), test_result=test_result, exc_info=(type(ex), ex, None)
)
_handle_test_exception((name, c), test_result, ex)
# Uncomment to investigate failure in detail
# raise ex
finally:
Expand Down
Loading