Skip to content

Commit 0b5f032

Browse files
committed
Refactor dimension records logic
1 parent 0c7c857 commit 0b5f032

1 file changed

Lines changed: 126 additions & 57 deletions

File tree

python/lsst/pipe/base/quantum_provenance_graph.py

Lines changed: 126 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
MissingDatasetTypeError,
7474
QuantumBackedButler,
7575
)
76+
from lsst.daf.butler.registry.queries import DimensionRecordQueryResults
7677
from lsst.resources import ResourcePathExpression
7778
from lsst.utils.logging import PeriodicLogger, getLogger
7879

@@ -659,7 +660,7 @@ def _add_quantum_info(
659660
self.recovered_quanta.append(dict(info["data_id"].required))
660661
if final_quantum_run is not None and final_quantum_run.caveats:
661662
code = final_quantum_run.caveats.concise()
662-
self.caveats.setdefault(code, []).append(dict(info["data_id"].required))
663+
self.caveats.setdefault(code, []).append(dict(info["data_id"].mapping))
663664
if final_quantum_run.caveats & QuantumSuccessCaveats.PARTIAL_OUTPUTS_ERROR:
664665
if final_quantum_run.exception is not None:
665666
self.exceptions.setdefault(final_quantum_run.exception.type_name, []).append(
@@ -964,7 +965,8 @@ def pprint(
964965
datasets: bool = True,
965966
show_exception_diagnostics: bool = False,
966967
butler: Butler | None = None,
967-
) -> None:
968+
return_exception_diagnostics_table: bool = False,
969+
) -> astropy.table.Table | None:
968970
"""Print this summary to stdout, as a series of tables.
969971
970972
Parameters
@@ -978,9 +980,21 @@ def pprint(
978980
includes a summary table of dataset counts for various status and
979981
(if ``brief`` is `True`) a table with per-data ID information for
980982
each unsuccessful or cursed dataset.
983+
show_exception_diagnostics : `bool`, optional
984+
If `True`, include a table of exception diagnostics in the output.
981985
butler : `lsst.daf.butler.Butler`, optional
982986
The butler used to create this summary. This is only used to get
983987
exposure dimension records for the exception diagnostics.
988+
return_exception_diagnostics_table : `bool`, optional
989+
If `True`, return the exception diagnostics table in addition to
990+
printing it. Only supported if ``show_exception_diagnostics`` is
991+
`True`.
992+
993+
Returns
994+
-------
995+
exception_diagnostics_table : `astropy.table.Table` or `None`
996+
A table of exception diagnostics, if requested and available.
997+
Otherwise, `None`.
984998
"""
985999
self.make_quantum_table().pprint_all()
9861000
print("")
@@ -991,24 +1005,47 @@ def pprint(
9911005
if exception_table := self.make_exception_table():
9921006
exception_table.pprint_all()
9931007
print("")
1008+
exception_diagnostics_table = None
9941009
if show_exception_diagnostics:
995-
exception_diagnostics_table = self.make_exception_diagnostics_table(
996-
butler, max_message_width=45, shorten_type_name=True
1010+
if return_exception_diagnostics_table:
1011+
# Keep an original copy of the table to be returned.
1012+
exception_diagnostics_table = self.make_exception_diagnostics_table(butler)
1013+
exception_diagnostics_table_view = exception_diagnostics_table.copy()
1014+
else:
1015+
exception_diagnostics_table_view = self.make_exception_diagnostics_table(butler)
1016+
if exception_diagnostics_table_view:
1017+
# Shorten the exception type name by trimming the module name.
1018+
exception_diagnostics_table_view["Exception"] = [
1019+
val.rsplit(".", maxsplit=1)[-1] if val is not None else val
1020+
for val in exception_diagnostics_table_view["Exception"]
1021+
]
1022+
# Shorten the exception message to a maximum width.
1023+
max_message_width = 45
1024+
exception_diagnostics_table_view["Exception Message"] = [
1025+
textwrap.shorten(msg, width=max_message_width, placeholder="...")
1026+
if msg and isinstance(msg, str) and len(msg) > max_message_width
1027+
else msg
1028+
for msg in exception_diagnostics_table_view["Exception Message"]
1029+
]
1030+
with self.tty_buffer() as buffer:
1031+
# Use pprint() to trim long tables; pprint_all() may flood
1032+
# the screen in those cases.
1033+
exception_diagnostics_table_view.pprint()
1034+
last_line = buffer.getvalue().splitlines()[-1]
1035+
# Print the table from the buffer.
1036+
print(buffer.getvalue())
1037+
if "Length =" in last_line:
1038+
# The table was too long to print, we had to truncate it.
1039+
print(
1040+
"▲ Note: The exception diagnostics table above is truncated. "
1041+
"Use --exception-diagnostics-filename to save the complete table."
1042+
)
1043+
print("")
1044+
elif return_exception_diagnostics_table:
1045+
raise ValueError(
1046+
"The exception diagnostics table was requested to be returned, "
1047+
"but `show_exception_diagnostics` is False."
9971048
)
998-
with self.tty_buffer() as buffer:
999-
# Use pprint() to trim long tables; pprint_all() may flood the
1000-
# screen in those cases.
1001-
exception_diagnostics_table.pprint()
1002-
last_line = buffer.getvalue().splitlines()[-1]
1003-
# Print the table from the buffer.
1004-
print(buffer.getvalue())
1005-
if "Length =" in last_line:
1006-
# The table was too long to print, so we had to truncate it.
1007-
print(
1008-
"▲ Note: The exception diagnostics table above is truncated. "
1009-
"Use --exception-diagnostics-filename to save the complete table."
1010-
)
1011-
print("")
10121049
if datasets:
10131050
self.make_dataset_table().pprint_all()
10141051
print("")
@@ -1022,6 +1059,7 @@ def pprint(
10221059
print(f"{dataset_type_name} errors:")
10231060
bad_dataset_table.pprint_all()
10241061
print("")
1062+
return exception_diagnostics_table
10251063

10261064
def make_quantum_table(self) -> astropy.table.Table:
10271065
"""Construct an `astropy.table.Table` with a tabular summary of the
@@ -1129,6 +1167,7 @@ def make_exception_table(self) -> astropy.table.Table:
11291167
def make_exception_diagnostics_table(
11301168
self,
11311169
butler: Butler | None = None,
1170+
add_dimension_records: bool = True,
11321171
add_exception_msg: bool = True,
11331172
max_message_width: int | None = None,
11341173
shorten_type_name: bool = False,
@@ -1145,8 +1184,10 @@ def make_exception_diagnostics_table(
11451184
Parameters
11461185
----------
11471186
butler : `lsst.daf.butler.Butler`, optional
1148-
Butler instance used to fetch exposure records. If not provided,
1149-
exposure dimension records will not be included in the table.
1187+
Butler instance used to fetch dimension records.
1188+
add_dimension_records : `bool`, optional
1189+
If `True`, include visit and exposure dimension records in the
1190+
table, if available. This requires ``butler`` to be provided.
11501191
add_exception_msg : `bool`, optional
11511192
If `True`, include the exception message in the table.
11521193
max_message_width : `int`, optional
@@ -1165,31 +1206,32 @@ def make_exception_diagnostics_table(
11651206
task), and optionally, exposure dimension records and exception
11661207
messages.
11671208
"""
1168-
add_exposure_records = True
1169-
needed_exposure_records = ["day_obs", "physical_filter", "exposure_time", "target_name"]
1170-
1171-
# Preload all exposure dimension records up front for faster O(1)
1172-
# lookup later. Querying per data ID in the loop is painfully slow.
1173-
if butler:
1174-
exposure_record_lookup = {
1175-
d.dataId["exposure"]: d for d in butler.query_dimension_records("exposure", explain=False)
1176-
}
1177-
else:
1178-
exposure_record_lookup = {}
1179-
add_exposure_records = False
1209+
if add_dimension_records and butler is None:
1210+
raise ValueError("Butler is required to fetch dimension records.")
11801211

1181-
if butler and not exposure_record_lookup:
1182-
_LOG.warning("No exposure records found in the butler; they will not be included in the table.")
1183-
add_exposure_records = False
1212+
# The additional columns for visit and exposure records to add to the
1213+
# output table, if available. Note that 'band', 'day_obs', and
1214+
# 'physical_filter' already exist in `exception.data_id` below.
1215+
needed_visit_records = ["exposure_time", "target_name"]
1216+
needed_exposure_records = ["exposure_time", "target_name"]
11841217

11851218
rows: defaultdict[tuple, defaultdict[str, str]] = defaultdict(lambda: defaultdict(str))
1219+
exposure_data_ids: list[dict] = []
1220+
visit_data_ids: list[dict] = []
1221+
dimension_record_lookup: dict[str, DimensionRecordQueryResults] = {}
11861222

11871223
# Loop over all tasks and exceptions, and associate them with data IDs.
11881224
for task_label, task_summary in self.tasks.items():
11891225
for type_name, exceptions in task_summary.exceptions.items():
11901226
for exception in exceptions:
1191-
data_id = exception.data_id
1192-
key = tuple(sorted(data_id.items())) # Hashable and stable
1227+
data_id = DataCoordinate.standardize(exception.data_id, universe=butler.dimensions)
1228+
if add_dimension_records:
1229+
if "visit" in data_id:
1230+
visit_data_ids.append(data_id)
1231+
elif "exposure" in data_id:
1232+
exposure_data_ids.append(data_id)
1233+
# Define a hashable and stable tuple of data ID values.
1234+
key = tuple(sorted(data_id.mapping.items()))
11931235
assert len(rows[key]) == 0, f"Multiple exceptions for one data ID: {key}"
11941236
assert rows[key]["Exception"] == "", f"Duplicate entry for data ID {key} in {task_label}"
11951237
if shorten_type_name:
@@ -1203,32 +1245,59 @@ def make_exception_diagnostics_table(
12031245
if max_message_width and len(msg) > max_message_width:
12041246
msg = textwrap.shorten(msg, max_message_width)
12051247
rows[key]["Exception Message"] = msg
1206-
if add_exposure_records:
1207-
exposure_record = exposure_record_lookup[data_id["exposure"]]
1208-
for k in needed_exposure_records:
1209-
rows[key][k] = getattr(exposure_record, k)
12101248

1211-
# Extract all unique columns.
1212-
all_columns = {col for r in rows.values() for col in r}
1213-
table_rows = []
1249+
if add_dimension_records and (visit_data_ids or exposure_data_ids):
1250+
# Preload all the dimension records up front for faster O(1) lookup
1251+
# later. Querying per data ID in the loop is painfully slow. These
1252+
# data IDs are limited to the ones that have exceptions.
1253+
with butler.query() as query:
1254+
query = query.join_data_coordinates(visit_data_ids + exposure_data_ids)
1255+
for element in ["visit", "exposure"]:
1256+
dimension_record_lookup |= {
1257+
f"{element}:{d.dataId[element]}": d for d in query.dimension_records(element)
1258+
}
1259+
1260+
# Loop over the data IDs and fill in the dimension records.
1261+
for element, data_ids, needed_records in zip(
1262+
["visit", "exposure"],
1263+
[visit_data_ids, exposure_data_ids],
1264+
[needed_visit_records, needed_exposure_records],
1265+
):
1266+
for data_id in data_ids:
1267+
key = tuple(sorted(data_id.mapping.items()))
1268+
for k in needed_records:
1269+
rows[key][k] = getattr(
1270+
                        dimension_record_lookup[f"{element}:{data_id[element]}"], k, None
1271+
)
1272+
1273+
# Extract all unique data ID keys from the rows for the table header.
1274+
all_key_columns = {k for key in rows for k, _ in key}
12141275

12151276
# Loop over all rows and add them to the table.
1216-
for key, col_counts in rows.items():
1217-
# Add data ID values as columns at the start of the row.
1218-
row = dict(key)
1219-
# Add exposure records next, if requested.
1220-
if add_exposure_records:
1221-
for col in needed_exposure_records:
1222-
row[col] = col_counts.get(col, "-")
1223-
# Add all other columns last.
1224-
for col in all_columns - set(needed_exposure_records) - {"Exception Message"}:
1225-
row[col] = col_counts.get(col, "-")
1226-
# Add the exception message if requested.
1277+
table_rows = []
1278+
for key, values in rows.items():
1279+
# Create a new row with all key columns initialized to None,
1280+
# allowing missing values to be properly masked when `masked=True`.
1281+
row = {col: None for col in all_key_columns}
1282+
# Fill in data ID fields from the key.
1283+
row.update(dict(key))
1284+
# Add dimension records next, if requested and available.
1285+
if add_dimension_records:
1286+
if visit_data_ids:
1287+
for col in needed_visit_records:
1288+
row[col] = values.get(col, None)
1289+
if exposure_data_ids:
1290+
for col in needed_exposure_records:
1291+
row[col] = values.get(col, None)
1292+
# Add task label and exception type.
1293+
for col in ("Task", "Exception"):
1294+
row[col] = values.get(col, None)
1295+
# Add the exception message, if requested.
12271296
if add_exception_msg:
1228-
row["Exception Message"] = col_counts.get("Exception Message", "-")
1297+
row["Exception Message"] = values.get("Exception Message", None)
12291298
table_rows.append(row)
12301299

1231-
return astropy.table.Table(table_rows)
1300+
return astropy.table.Table(table_rows, masked=True)
12321301

12331302
def make_bad_quantum_tables(self, max_message_width: int = 80) -> dict[str, astropy.table.Table]:
12341303
"""Construct an `astropy.table.Table` with per-data-ID information

0 commit comments

Comments
 (0)