|
2 | 2 | import json |
3 | 3 | from datetime import datetime |
4 | 4 | from enum import Flag, auto |
| 5 | +from pathlib import Path |
5 | 6 |
|
| 7 | +import matplotlib.dates as mdates |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import matplotlib.ticker as mticker |
6 | 10 | import numpy as np |
7 | 11 | import pandas as pd |
8 | 12 | import scipy |
|
17 | 21 |
|
18 | 22 | from ml.inference import load_model |
19 | 23 | from ml.types import InferenceModel |
| 24 | +from shared.color import COLOR_PALETTE |
20 | 25 | from shared.lakehouse import Lakehouse |
| 26 | +from shared.settings import LOCAL_DIR |
21 | 27 |
|
22 | 28 |
|
23 | 29 | class MonitoringStats(Flag): |
@@ -411,4 +417,58 @@ def load(self): |
411 | 417 | self.stats = self.lh.ml_monitoring_load(self.schema) |
412 | 418 |
|
413 | 419 | def plot(self): |
414 | | - pass |
| 420 | + output_dir = Path(LOCAL_DIR) / "monitor" |
| 421 | + output_dir.mkdir(parents=True, exist_ok=True) |
| 422 | + |
| 423 | + data = self.stats.copy() |
| 424 | + |
| 425 | + data["model"] = data.model_name + " (" + data.model_version + ")" |
| 426 | + data = data.drop(columns=["model_name", "model_version"]) |
| 427 | + |
| 428 | + metric_names = { |
| 429 | + "count": "Number of Inferences Over Time", |
| 430 | + "pred_drift": "Prediction Shift (KS D-Statistic)", |
| 431 | + "feat_drift": "Feature Drift (ROC AUC)", |
| 432 | + "e_f1": "Estimated F1-Score Based on CBPE", |
| 433 | + "e_accuracy": "Estimated Accuracy Based on CBPE", |
| 434 | + "user_brier": "Mean Brier Score Based on Avg. User Feedback", |
| 435 | + } |
| 436 | + |
| 437 | + metrics = list(set(data.columns) & set(metric_names.keys())) |
| 438 | + |
| 439 | + data = data[["date", "model"] + metrics].pivot( |
| 440 | + index="date", |
| 441 | + columns="model", |
| 442 | + values=metrics, |
| 443 | + ) |
| 444 | + |
| 445 | + plt.rcParams["axes.prop_cycle"] = plt.cycler(color=COLOR_PALETTE) |
| 446 | + |
| 447 | + for metric in metrics: |
| 448 | + output_path = output_dir / f"{metric}.png" |
| 449 | + metric_name = metric_names[metric] |
| 450 | + |
| 451 | + log.info("Plotting {} into {}", metric_name, output_path) |
| 452 | + |
| 453 | + fig, ax = plt.subplots(figsize=(7, 3.5), dpi=300) |
| 454 | + |
| 455 | + data[metric].plot.bar(ax=ax, rot=0, xlabel="") |
| 456 | + |
| 457 | + step = 7 |
| 458 | + ax.set_xticks(range(0, len(data[metric]), step)) |
| 459 | + ax.set_xticklabels(data[metric].index[::step].strftime("%d %b %Y")) |
| 460 | + |
| 461 | + plt.xticks(ha="left") |
| 462 | + |
| 463 | + plt.legend( |
| 464 | + title=None, |
| 465 | + ncol=2, |
| 466 | + loc="upper left", |
| 467 | + bbox_to_anchor=(0, -0.175), |
| 468 | + borderaxespad=0, |
| 469 | + ) |
| 470 | + |
| 471 | + plt.title(metric_name) |
| 472 | + plt.tight_layout() |
| 473 | + |
| 474 | + fig.savefig(output_path) |
0 commit comments