|
1 | | -import asyncio |
2 | | -from collections.abc import Awaitable |
3 | | -from inspect import getmembers, isfunction |
4 | 1 | import sys |
5 | | -from typing import Callable, cast, Protocol, Set, Union |
| 2 | +from typing import List |
6 | 3 |
|
7 | | -from rich.prompt import Prompt |
| 4 | +from pytest import console_main |
8 | 5 |
|
9 | | -"""Tools for manual testing with real hardware.""" |
10 | 6 |
|
| 7 | +def parse_args(raw_args: List[str] = sys.argv[1:]) -> List[str]: |
| 8 | + """ |
| 9 | + Given arguments for pytest, ensure that -s or --capture=no is set. |
11 | 10 |
|
12 | | -class AbortError(Exception): |
13 | | - """A manual testing step has been aborted.""" |
| 11 | + This function is not entirely robust, as it doesn't implement full arguments |
| 12 | + parsing as in pytest. If an option is passed "-s" or "--capture=no" as a value, |
| 13 | + this will fail to add the appropriate flag. But those cases should be very |
| 14 | + rare, and a full options parse is difficult. |
| 15 | + """ |
14 | 16 |
|
15 | | - pass |
| 17 | + args: List[str] = list() |
16 | 18 |
|
| 19 | + has_no_capture_flag = False |
| 20 | + for i, arg in enumerate(raw_args): |
| 21 | + if arg == "--capture=no" or arg == "-s": |
| 22 | + has_no_capture_flag = True |
| 23 | + args += raw_args[i:] |
| 24 | + break |
17 | 25 |
|
18 | | -def confirm(text: str) -> None: |
19 | | - """Manually confirm an expected state.""" |
20 | | - |
21 | | - res = Prompt.ask(text, choices=["confirm", "abort"]) |
22 | | - |
23 | | - if res == "abort": |
24 | | - raise AbortError("Aborted.") |
25 | | - |
26 | | - |
27 | | -def take_action(text: str) -> None: |
28 | | - """Take a manual action before continuing.""" |
29 | | - |
30 | | - res = Prompt.ask(text, choices=["continue", "abort"]) |
31 | | - |
32 | | - if res == "abort": |
33 | | - raise AbortError("Aborted.") |
34 | | - |
35 | | - |
36 | | -def check(text: str, expected: str) -> None: |
37 | | - """Manually check whether or not an expected state is so.""" |
38 | | - |
39 | | - res = Prompt.ask(text, choices=["yes", "no", "abort"]) |
40 | | - |
41 | | - if res == "abort": |
42 | | - raise AbortError("Aborted.") |
43 | | - |
44 | | - assert res == "yes", expected |
45 | | - |
46 | | - |
47 | | -class MarkedTest(Protocol): |
48 | | - marks: Set[str] |
49 | | - |
50 | | - def __call__(self) -> Awaitable[None]: ... |
51 | | - |
52 | | - |
53 | | -UnmarkedTest = Callable[[], Awaitable[None]] |
54 | | - |
55 | | -Test = Union[UnmarkedTest, MarkedTest] |
56 | | - |
57 | | - |
58 | | -def mark(tag: str) -> Callable[[Test], MarkedTest]: |
59 | | - def decorator(test: Test) -> MarkedTest: |
60 | | - marked = cast(MarkedTest, test) |
61 | | - |
62 | | - if not hasattr(test, "marks"): |
63 | | - marked.marks = set() |
64 | | - |
65 | | - marked.marks.add(tag) |
66 | | - |
67 | | - return marked |
68 | | - |
69 | | - return decorator |
70 | | - |
71 | | - |
72 | | -def skip(test: Test) -> MarkedTest: |
73 | | - """Skip a test.""" |
74 | | - |
75 | | - return mark("skip")(test) |
76 | | - |
| 26 | + if arg.startswith("--capture="): |
| 27 | + # Drop the flag, since we're going to override it later |
| 28 | + continue |
77 | 29 |
|
78 | | -def marked_with(tag: str, test: Test) -> bool: |
79 | | - """Check if a test has a mark.""" |
| 30 | + args.append(arg) |
80 | 31 |
|
81 | | - if not hasattr(test, "marks"): |
82 | | - return False |
| 32 | + if not has_no_capture_flag: |
| 33 | + args.insert(0, "--capture=no") |
83 | 34 |
|
84 | | - return tag in cast(MarkedTest, test).marks |
| 35 | + return args |
85 | 36 |
|
86 | 37 |
|
87 | | -async def _run_tests(__name__: str) -> None: |
88 | | - for name, test in getmembers(sys.modules[__name__], isfunction): |
89 | | - if not name.startswith("test_"): |
90 | | - continue |
| 38 | +def main() -> int: |
| 39 | + """ |
| 40 | + A command line entry point that calls pytest with --capture=no set. |
| 41 | + """ |
91 | 42 |
|
92 | | - if marked_with("skip", test): |
93 | | - print(f"=== {name} SKIPPED ===") |
94 | | - continue |
| 43 | + args = parse_args() |
| 44 | + # sys.argv[0] is typically the command that was run |
| 45 | + args.insert(0, "pytest") |
95 | 46 |
|
96 | | - print(f"=== {name} ===") |
97 | | - try: |
98 | | - await test() |
99 | | - except Exception as exc: |
100 | | - print(f"{name} FAILED") |
101 | | - print(exc) |
102 | | - else: |
103 | | - print(f"=== {name} PASSED ===") |
| 47 | + # Patch sys.argv so the pytest entry point picks it up |
| 48 | + sys.argv = args |
104 | 49 |
|
| 50 | + # Call the standard pytest entry point with modified args |
| 51 | + return console_main() |
105 | 52 |
|
106 | | -def run_tests(__name__: str) -> None: |
107 | | - """Run integration tests in module.""" |
108 | 53 |
|
109 | | - loop = asyncio.get_event_loop() |
110 | | - loop.run_until_complete(_run_tests(__name__)) |
| 54 | +if __name__ == "__main__": |
| 55 | + sys.exit(main()) |
0 commit comments