Skip to content

Commit 41851aa

Browse files
author
OutlyingWest
committed
basic functionality to reading notebooks
1 parent 03754b2 commit 41851aa

1 file changed

Lines changed: 70 additions & 23 deletions

File tree

tests/test_kernel.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,90 @@ def tearDownClass(cls) -> None:
3333

3434
def check_stream_output(self, code, expected_output, stream="stdout"):
3535
self.flush_channels()
36-
reply, output_msgs = self.execute_helper(code=code)
37-
for msg, expected_msg in zip(output_msgs, expected_output):
36+
reply, output_messages = self.execute_helper(code=code)
37+
from pprint import pprint
38+
39+
for expected_msg in expected_output:
3840
# replace env vars
3941
expected_msg = os.path.expandvars(expected_msg)
40-
# self.assertEqual(msg["header"]["msg_type"], "stream")
41-
# some messages can be of type 'execute_result'
42-
# type instead of stdout
43-
# self.assertEqual(msg["content"]["name"], stream)
42+
for msg in output_messages:
43+
self.extract_message_output(msg)
4444

45-
if msg["header"]["msg_type"] == "stream":
46-
# self.assertEqual(msg["content"]["name"], stream)
47-
self.assertEqual(
48-
clean_console_output(msg["content"]["text"]),
49-
clean_console_output(expected_msg)
50-
)
51-
elif msg["header"]["msg_type"] == "execute_result":
52-
self.assertEqual(
53-
clean_console_output(msg["content"]["data"]["text/plain"]),
54-
clean_console_output(expected_msg)
55-
)
5645

5746

58-
def check_from_file(self, filename):
5947

48+
def check_from_file(self, filename):
6049
with open(filename, "r") as file:
6150
cells = yaml.safe_load(file)
6251

6352
for idx, (code, expected_output) in enumerate(cells):
6453
with self.subTest(block=idx, code_line=code.splitlines()[0]):
6554
self.check_stream_output(code, expected_output)
6655

56+
def check_from_notebook(self, notebook_path: str):
57+
nb = nbformat.read(open(notebook_path), as_version=4)
58+
from pprint import pprint
59+
# pprint(nb.cells)
60+
61+
for idx, cell in enumerate(nb.cells):
62+
if cell.cell_type != "code":
63+
continue
64+
65+
cell_code = cell.source
66+
cell_outputs = cell.get("outputs", [])
67+
reply, output_messages = self.execute_helper(code=cell_code)
68+
69+
expected_outputs = self.extract_notebook_cell_outputs(cell_outputs)
70+
kernel_outputs = self.extract_kernel_executed_outputs(output_messages)
71+
72+
# print(idx)
73+
# pprint(expected_outputs)
74+
# print('---------------------------------')
75+
# pprint(kernel_outputs)
76+
# print()
77+
78+
# with self.subTest(cell=idx, code_line=cell_code.splitlines()[0] if cell_code.strip() else "<empty>"):
79+
# self.assertListEqual()
80+
81+
# with self.subTest(cell=idx, code_line=code.splitlines()[0] if code.strip() else "<empty>"):
82+
# self.check_stream_output(code, expected_outputs)
83+
84+
@staticmethod
85+
def extract_notebook_cell_outputs(cell_outputs: list) -> list:
86+
expected_outputs = []
87+
for output in cell_outputs:
88+
if output.output_type == "stream":
89+
message_text = output.get("text", "")
90+
elif output.output_type == "execute_result":
91+
message_text = output["data"].get("text/plain", "")
92+
elif output.output_type == "error":
93+
message_text = "\n".join(output["traceback"])
94+
else:
95+
message_text = ''
96+
message_text = message_text.strip()
97+
print(f'{message_text=}')
98+
expected_outputs.append(message_text)
99+
return expected_outputs
100+
101+
@staticmethod
102+
def extract_kernel_executed_outputs(output_messages: list) -> list:
103+
kernel_outputs = []
104+
for msg in output_messages:
105+
if msg["header"]["msg_type"] == "stream":
106+
message_text = msg["content"]["text"]
107+
elif msg["header"]["msg_type"] == "execute_result":
108+
message_text = msg["content"]["data"]["text/plain"]
109+
else:
110+
message_text = ''
111+
112+
if '\x00' not in message_text and '\r' not in message_text:
113+
kernel_outputs.append(message_text.strip())
114+
115+
return kernel_outputs
116+
67117
# Enumerate tests to ensure proper execution order
68118
def test_00_scorep_env(self):
69-
self.check_from_file("tests/kernel/scorep_env.yaml")
119+
self.check_from_notebook("tests/kernel/test_scorep_kernel.ipynb")
70120

71121
def test_01_scorep_pythonargs(self):
72122
self.check_from_file("tests/kernel/scorep_pythonargs.yaml")
@@ -87,9 +137,6 @@ def test_06_writemode(self):
87137
self.check_from_file("tests/kernel/writemode.yaml")
88138

89139

90-
def clean_console_output(text):
91-
return text.replace('\r', '').strip()
92-
93-
94140
if __name__ == "__main__":
141+
# KernelTests.check_from_notebook("tests/kernel/notebook.ipynb")
95142
unittest.main()

0 commit comments

Comments
 (0)