Skip to content

Commit 39d0914

Browse files
committed
Refactor dimension records logic
1 parent 0c7c857 commit 39d0914

1 file changed

Lines changed: 125 additions & 57 deletions

File tree

python/lsst/pipe/base/quantum_provenance_graph.py

Lines changed: 125 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def _add_quantum_info(
659659
self.recovered_quanta.append(dict(info["data_id"].required))
660660
if final_quantum_run is not None and final_quantum_run.caveats:
661661
code = final_quantum_run.caveats.concise()
662-
self.caveats.setdefault(code, []).append(dict(info["data_id"].required))
662+
self.caveats.setdefault(code, []).append(dict(info["data_id"].mapping))
663663
if final_quantum_run.caveats & QuantumSuccessCaveats.PARTIAL_OUTPUTS_ERROR:
664664
if final_quantum_run.exception is not None:
665665
self.exceptions.setdefault(final_quantum_run.exception.type_name, []).append(
@@ -964,7 +964,8 @@ def pprint(
964964
datasets: bool = True,
965965
show_exception_diagnostics: bool = False,
966966
butler: Butler | None = None,
967-
) -> None:
967+
return_exception_diagnostics_table: bool = False,
968+
) -> astropy.table.Table | None:
968969
"""Print this summary to stdout, as a series of tables.
969970
970971
Parameters
@@ -978,9 +979,21 @@ def pprint(
978979
includes a summary table of dataset counts for various status and
979980
(if ``brief`` is `True`) a table with per-data ID information for
980981
each unsuccessful or cursed dataset.
982+
show_exception_diagnostics : `bool`, optional
983+
If `True`, include a table of exception diagnostics in the output.
981984
butler : `lsst.daf.butler.Butler`, optional
982985
The butler used to create this summary. This is only used to get
983986
exposure dimension records for the exception diagnostics.
987+
return_exception_diagnostics_table : `bool`, optional
988+
If `True`, return the exception diagnostics table in addition to
989+
printing it. Only supported if ``show_exception_diagnostics`` is
990+
`True`.
991+
992+
Returns
993+
-------
994+
exception_diagnostics_table : `astropy.table.Table` or `None`
995+
A table of exception diagnostics, if requested and available.
996+
Otherwise, `None`.
984997
"""
985998
self.make_quantum_table().pprint_all()
986999
print("")
@@ -991,24 +1004,47 @@ def pprint(
9911004
if exception_table := self.make_exception_table():
9921005
exception_table.pprint_all()
9931006
print("")
1007+
exception_diagnostics_table = None
9941008
if show_exception_diagnostics:
995-
exception_diagnostics_table = self.make_exception_diagnostics_table(
996-
butler, max_message_width=45, shorten_type_name=True
1009+
if return_exception_diagnostics_table:
1010+
# Keep an original copy of the table to be returned.
1011+
exception_diagnostics_table = self.make_exception_diagnostics_table(butler)
1012+
exception_diagnostics_table_view = exception_diagnostics_table.copy()
1013+
else:
1014+
exception_diagnostics_table_view = self.make_exception_diagnostics_table(butler)
1015+
if exception_diagnostics_table_view:
1016+
# Shorten the exception type name by trimming the module name.
1017+
exception_diagnostics_table_view["Exception"] = [
1018+
val.rsplit(".", maxsplit=1)[-1] if val is not None else val
1019+
for val in exception_diagnostics_table_view["Exception"]
1020+
]
1021+
# Shorten the exception message to a maximum width.
1022+
max_message_width = 45
1023+
exception_diagnostics_table_view["Exception Message"] = [
1024+
textwrap.shorten(msg, width=max_message_width, placeholder="...")
1025+
if msg and isinstance(msg, str) and len(msg) > max_message_width
1026+
else msg
1027+
for msg in exception_diagnostics_table_view["Exception Message"]
1028+
]
1029+
with self.tty_buffer() as buffer:
1030+
# Use pprint() to trim long tables; pprint_all() may flood
1031+
# the screen in those cases.
1032+
exception_diagnostics_table_view.pprint()
1033+
last_line = buffer.getvalue().splitlines()[-1]
1034+
# Print the table from the buffer.
1035+
print(buffer.getvalue())
1036+
if "Length =" in last_line:
1037+
# The table was too long to print, we had to truncate it.
1038+
print(
1039+
"▲ Note: The exception diagnostics table above is truncated. "
1040+
"Use --exception-diagnostics-filename to save the complete table."
1041+
)
1042+
print("")
1043+
elif return_exception_diagnostics_table:
1044+
raise ValueError(
1045+
"The exception diagnostics table was requested to be returned, "
1046+
"but `show_exception_diagnostics` is False."
9971047
)
998-
with self.tty_buffer() as buffer:
999-
# Use pprint() to trim long tables; pprint_all() may flood the
1000-
# screen in those cases.
1001-
exception_diagnostics_table.pprint()
1002-
last_line = buffer.getvalue().splitlines()[-1]
1003-
# Print the table from the buffer.
1004-
print(buffer.getvalue())
1005-
if "Length =" in last_line:
1006-
# The table was too long to print, so we had to truncate it.
1007-
print(
1008-
"▲ Note: The exception diagnostics table above is truncated. "
1009-
"Use --exception-diagnostics-filename to save the complete table."
1010-
)
1011-
print("")
10121048
if datasets:
10131049
self.make_dataset_table().pprint_all()
10141050
print("")
@@ -1022,6 +1058,7 @@ def pprint(
10221058
print(f"{dataset_type_name} errors:")
10231059
bad_dataset_table.pprint_all()
10241060
print("")
1061+
return exception_diagnostics_table
10251062

10261063
def make_quantum_table(self) -> astropy.table.Table:
10271064
"""Construct an `astropy.table.Table` with a tabular summary of the
@@ -1129,6 +1166,7 @@ def make_exception_table(self) -> astropy.table.Table:
11291166
def make_exception_diagnostics_table(
11301167
self,
11311168
butler: Butler | None = None,
1169+
add_dimension_records: bool = True,
11321170
add_exception_msg: bool = True,
11331171
max_message_width: int | None = None,
11341172
shorten_type_name: bool = False,
@@ -1145,8 +1183,10 @@ def make_exception_diagnostics_table(
11451183
Parameters
11461184
----------
11471185
butler : `lsst.daf.butler.Butler`, optional
1148-
Butler instance used to fetch exposure records. If not provided,
1149-
exposure dimension records will not be included in the table.
1186+
Butler instance used to fetch dimension records.
1187+
add_dimension_records : `bool`, optional
1188+
If `True`, include visit and exposure dimension records in the
1189+
table, if available. This requires ``butler`` to be provided.
11501190
add_exception_msg : `bool`, optional
11511191
If `True`, include the exception message in the table.
11521192
max_message_width : `int`, optional
@@ -1165,31 +1205,32 @@ def make_exception_diagnostics_table(
11651205
task), and optionally, exposure dimension records and exception
11661206
messages.
11671207
"""
1168-
add_exposure_records = True
1169-
needed_exposure_records = ["day_obs", "physical_filter", "exposure_time", "target_name"]
1170-
1171-
# Preload all exposure dimension records up front for faster O(1)
1172-
# lookup later. Querying per data ID in the loop is painfully slow.
1173-
if butler:
1174-
exposure_record_lookup = {
1175-
d.dataId["exposure"]: d for d in butler.query_dimension_records("exposure", explain=False)
1176-
}
1177-
else:
1178-
exposure_record_lookup = {}
1179-
add_exposure_records = False
1208+
if add_dimension_records and butler is None:
1209+
raise ValueError("Butler is required to fetch dimension records.")
11801210

1181-
if butler and not exposure_record_lookup:
1182-
_LOG.warning("No exposure records found in the butler; they will not be included in the table.")
1183-
add_exposure_records = False
1211+
# The additional columns for visit and exposure records to add to the
1212+
# output table, if available. Note that 'band', 'day_obs', and
1213+
# 'physical_filter' already exist in `exception.data_id` below.
1214+
needed_visit_records = ["exposure_time", "target_name"]
1215+
needed_exposure_records = ["exposure_time", "target_name"]
11841216

11851217
rows: defaultdict[tuple, defaultdict[str, str]] = defaultdict(lambda: defaultdict(str))
1218+
exposure_data_ids: list[dict] = []
1219+
visit_data_ids: list[dict] = []
1220+
dimension_record_lookup: dict[str, DimensionRecord] = {}
11861221

11871222
# Loop over all tasks and exceptions, and associate them with data IDs.
11881223
for task_label, task_summary in self.tasks.items():
11891224
for type_name, exceptions in task_summary.exceptions.items():
11901225
for exception in exceptions:
1191-
data_id = exception.data_id
1192-
key = tuple(sorted(data_id.items())) # Hashable and stable
1226+
data_id = DataCoordinate.standardize(exception.data_id, universe=butler.dimensions)
1227+
if add_dimension_records:
1228+
if "visit" in data_id:
1229+
visit_data_ids.append(data_id)
1230+
elif "exposure" in data_id:
1231+
exposure_data_ids.append(data_id)
1232+
# Define a hashable and stable tuple of data ID values.
1233+
key = tuple(sorted(data_id.mapping.items()))
11931234
assert len(rows[key]) == 0, f"Multiple exceptions for one data ID: {key}"
11941235
assert rows[key]["Exception"] == "", f"Duplicate entry for data ID {key} in {task_label}"
11951236
if shorten_type_name:
@@ -1203,32 +1244,59 @@ def make_exception_diagnostics_table(
12031244
if max_message_width and len(msg) > max_message_width:
12041245
msg = textwrap.shorten(msg, max_message_width)
12051246
rows[key]["Exception Message"] = msg
1206-
if add_exposure_records:
1207-
exposure_record = exposure_record_lookup[data_id["exposure"]]
1208-
for k in needed_exposure_records:
1209-
rows[key][k] = getattr(exposure_record, k)
12101247

1211-
# Extract all unique columns.
1212-
all_columns = {col for r in rows.values() for col in r}
1213-
table_rows = []
1248+
if add_dimension_records and (visit_data_ids or exposure_data_ids):
1249+
# Preload all the dimension records up front for faster O(1) lookup
1250+
# later. Querying per data ID in the loop is painfully slow. These
1251+
# data IDs are limited to the ones that have exceptions.
1252+
with butler.query() as query:
1253+
query = query.join_data_coordinates(visit_data_ids + exposure_data_ids)
1254+
for element in ["visit", "exposure"]:
1255+
dimension_record_lookup |= {
1256+
f"{element}:{d.dataId[element]}": d for d in query.dimension_records(element)
1257+
}
1258+
1259+
# Loop over the data IDs and fill in the dimension records.
1260+
for element, data_ids, needed_records in zip(
1261+
["visit", "exposure"],
1262+
[visit_data_ids, exposure_data_ids],
1263+
[needed_visit_records, needed_exposure_records],
1264+
):
1265+
for data_id in data_ids:
1266+
key = tuple(sorted(data_id.mapping.items()))
1267+
for k in needed_records:
1268+
rows[key][k] = getattr(
1269+
dimension_record_lookup[f"{element}:{d.dataId[element]}"], k, None
1270+
)
1271+
1272+
# Extract all unique data ID keys from the rows for the table header.
1273+
all_key_columns = {k for key in rows for k, _ in key}
12141274

12151275
# Loop over all rows and add them to the table.
1216-
for key, col_counts in rows.items():
1217-
# Add data ID values as columns at the start of the row.
1218-
row = dict(key)
1219-
# Add exposure records next, if requested.
1220-
if add_exposure_records:
1221-
for col in needed_exposure_records:
1222-
row[col] = col_counts.get(col, "-")
1223-
# Add all other columns last.
1224-
for col in all_columns - set(needed_exposure_records) - {"Exception Message"}:
1225-
row[col] = col_counts.get(col, "-")
1226-
# Add the exception message if requested.
1276+
table_rows = []
1277+
for key, values in rows.items():
1278+
# Create a new row with all key columns initialized to None,
1279+
# allowing missing values to be properly masked when `masked=True`.
1280+
row = {col: None for col in all_key_columns}
1281+
# Fill in data ID fields from the key.
1282+
row.update(dict(key))
1283+
# Add dimension records next, if requested and available.
1284+
if add_dimension_records:
1285+
if visit_data_ids:
1286+
for col in needed_visit_records:
1287+
row[col] = values.get(col, None)
1288+
if exposure_data_ids:
1289+
for col in needed_exposure_records:
1290+
row[col] = values.get(col, None)
1291+
# Add task label and exception type.
1292+
for col in ("Task", "Exception"):
1293+
row[col] = values.get(col, None)
1294+
# Add the exception message, if requested.
12271295
if add_exception_msg:
1228-
row["Exception Message"] = col_counts.get("Exception Message", "-")
1296+
row["Exception Message"] = values.get("Exception Message", None)
12291297
table_rows.append(row)
12301298

1231-
return astropy.table.Table(table_rows)
1299+
return astropy.table.Table(table_rows, masked=True)
12321300

12331301
def make_bad_quantum_tables(self, max_message_width: int = 80) -> dict[str, astropy.table.Table]:
12341302
"""Construct an `astropy.table.Table` with per-data-ID information

0 commit comments

Comments
 (0)