diff --git a/configs/experiment/ml/final_linear_provgigapath_adamw.yaml b/configs/experiment/ml/final_linear_provgigapath_adamw.yaml new file mode 100644 index 0000000..2de2584 --- /dev/null +++ b/configs/experiment/ml/final_linear_provgigapath_adamw.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - /experiment/ml/final_linear_virchow2_adamw + - _self_ + +embedding_model_name: ProvGigaPath +embedding_dim: 1536 +embedding_run_id: 410c8672471348ceb4c58817f70fa097 +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} +mlflow_artifact_path: linear_classifier_final_provgigapath + +# Set after Stage 1 from ProvGigaPath's own AdamW sweep selected by +# validation/f1_macro. +model: + weight_decay: 1.0e-4 + +metadata: + run_name: Final Linear Classifier AdamW ProvGigaPath ${dataset.name} + description: "Final AdamW linear probe over frozen ProvGigaPath embeddings, trained on all training folds with the ProvGigaPath-selected weight decay." diff --git a/configs/experiment/ml/final_linear_provgigapath_lbfgs.yaml b/configs/experiment/ml/final_linear_provgigapath_lbfgs.yaml new file mode 100644 index 0000000..49c71f3 --- /dev/null +++ b/configs/experiment/ml/final_linear_provgigapath_lbfgs.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +defaults: + - /experiment/ml/final_linear_virchow2_lbfgs + - _self_ + +embedding_model_name: ProvGigaPath +embedding_dim: 1536 +embedding_run_id: 410c8672471348ceb4c58817f70fa097 +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} +mlflow_artifact_path: linear_classifier_final_provgigapath + +# Set after Stage 1 from ProvGigaPath's own LBFGS sweep selected by +# validation/f1_macro. +model: + weight_decay: 1.0e-4 + +metadata: + run_name: Final Linear Classifier LBFGS ProvGigaPath ${dataset.name} + description: "Final LBFGS linear probe over frozen ProvGigaPath embeddings, exact full-batch solve with the ProvGigaPath-selected weight decay." diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/final_linear_virchow2_adamw.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_final_adamw.yaml rename to configs/experiment/ml/final_linear_virchow2_adamw.yaml diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/final_linear_virchow2_lbfgs.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_final_lbfgs.yaml rename to configs/experiment/ml/final_linear_virchow2_lbfgs.yaml diff --git a/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml deleted file mode 100644 index 5d09c02..0000000 --- a/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package _global_ - -defaults: - - /ml/task: kfold_linear_classifier - - _self_ - -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} - -model: - optimizer: adamw - learning_rate: 1.0e-4 - weight_decay: 0.0 diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/predict_linear_virchow2_lbfgs_tissue_tiles.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml rename to configs/experiment/ml/predict_linear_virchow2_lbfgs_tissue_tiles.yaml diff --git a/configs/experiment/ml/test_linear_provgigapath_adamw.yaml b/configs/experiment/ml/test_linear_provgigapath_adamw.yaml new file mode 100644 index 0000000..4d94a94 --- /dev/null +++ b/configs/experiment/ml/test_linear_provgigapath_adamw.yaml @@ -0,0 +1,16 @@ +# @package _global_ + +defaults: + - /experiment/ml/final_linear_provgigapath_adamw + - _self_ + +# Held-out test for the final ProvGigaPath AdamW checkpoint. Uses the same +# filtered labeled test split, thresholds, metrics, and checkpoint convention as +# the Virchow2 test config. +mode: test +final_train_run_id: fe172ccd8c1140269f7f3d1fdbd351ea +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false + +data: + num_workers: 0 diff --git a/configs/experiment/ml/test_linear_provgigapath_lbfgs.yaml b/configs/experiment/ml/test_linear_provgigapath_lbfgs.yaml new file mode 100644 index 0000000..bb8ded9 --- /dev/null +++ b/configs/experiment/ml/test_linear_provgigapath_lbfgs.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - /experiment/ml/final_linear_provgigapath_lbfgs + - override /ml/trainer: early_stopping + - _self_ + +# Held-out test for the final ProvGigaPath LBFGS checkpoint. Uses the same +# filtered labeled test split, thresholds, metrics, and checkpoint convention as +# the Virchow2 test config. +mode: test +final_train_run_id: 067b08dcbdb54d9187fbd4dd8d5599a1 +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false + +data: + train_batch_size: 1024 + num_workers: 0 diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/test_linear_virchow2_adamw.yaml similarity index 69% rename from configs/experiment/ml/linear_classifier_test_adamw.yaml rename to configs/experiment/ml/test_linear_virchow2_adamw.yaml index ed214e5..9e3c398 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/test_linear_virchow2_adamw.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_final_adamw + - /experiment/ml/final_linear_virchow2_adamw - _self_ # Test the AdamW final checkpoint on the held-out test split. Same model @@ -20,13 +20,3 @@ checkpoint_weights_only: false # before the first test batch. final_embedding_tiles defaults to 4; override here. data: num_workers: 0 - -trainer: - callbacks: - tiff_prediction_maps: - _target_: ml.callbacks.TiffPredictionMapWriter - slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet - artifact_path: prediction_maps_tiff - draw_region: central_stride - slide_selection: all - max_slides: null diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/test_linear_virchow2_lbfgs.yaml similarity index 96% rename from configs/experiment/ml/linear_classifier_test_lbfgs.yaml rename to configs/experiment/ml/test_linear_virchow2_lbfgs.yaml index c005c37..0d60bd1 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/test_linear_virchow2_lbfgs.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_final_lbfgs + - /experiment/ml/final_linear_virchow2_lbfgs - override /ml/trainer: early_stopping - _self_ diff --git a/configs/experiment/ml/train_linear_provgigapath_adamw_group_kfold.yaml b/configs/experiment/ml/train_linear_provgigapath_adamw_group_kfold.yaml new file mode 100644 index 0000000..797f6e0 --- /dev/null +++ b/configs/experiment/ml/train_linear_provgigapath_adamw_group_kfold.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +defaults: + - /experiment/ml/train_linear_virchow2_adamw_group_kfold + - _self_ + +embedding_model_name: ProvGigaPath +embedding_dim: 1536 +embedding_run_id: 410c8672471348ceb4c58817f70fa097 +mlflow_artifact_path: linear_classifier_provgigapath + +metadata: + run_name: Linear Classifier ProvGigaPath ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} + description: "Linear probe over frozen ProvGigaPath embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/train_linear_provgigapath_lbfgs_group_kfold.yaml similarity index 83% rename from configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml rename to configs/experiment/ml/train_linear_provgigapath_lbfgs_group_kfold.yaml index f857ccd..bdfaff0 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml +++ b/configs/experiment/ml/train_linear_provgigapath_lbfgs_group_kfold.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_stratified_kfold + - /experiment/ml/train_linear_provgigapath_adamw_group_kfold - _self_ trainer: @@ -11,6 +11,7 @@ data: train_batch_size: 1000000000 train_shuffle: false train_drop_last: false + num_workers: 0 model: optimizer: lbfgs diff --git a/configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml b/configs/experiment/ml/train_linear_virchow2_adamw_group_kfold.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml rename to configs/experiment/ml/train_linear_virchow2_adamw_group_kfold.yaml diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/train_linear_virchow2_lbfgs_group_kfold.yaml similarity index 87% rename from configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml rename to configs/experiment/ml/train_linear_virchow2_lbfgs_group_kfold.yaml index 3f11835..809f522 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml +++ b/configs/experiment/ml/train_linear_virchow2_lbfgs_group_kfold.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_stratified_group_kfold + - /experiment/ml/train_linear_virchow2_adamw_group_kfold - _self_ trainer: diff --git a/configs/experiment/preprocessing/embeddings_prov_gigapath_05mpp.yaml b/configs/experiment/preprocessing/embeddings_provgigapath_0_5mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/embeddings_prov_gigapath_05mpp.yaml rename to configs/experiment/preprocessing/embeddings_provgigapath_0_5mpp.yaml diff --git a/configs/experiment/preprocessing/embeddings_virchow2_05mpp.yaml b/configs/experiment/preprocessing/embeddings_virchow2_0_5mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/embeddings_virchow2_05mpp.yaml rename to configs/experiment/preprocessing/embeddings_virchow2_0_5mpp.yaml diff --git a/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml deleted file mode 100644 index 4ea45ff..0000000 --- a/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# @package _global_ - -defaults: - - /experiment/preprocessing/embeddings_virchow2_05mpp - - _self_ - -# Embeddings for every tile in the train/test tiling split that intersects the -# tissue mask. This is the tile universe used for doctor-review prediction maps. -splits: - - test -tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} -tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet" -tile_filter_column: tile_tissue_coverage - -metadata: - run_name: "Embeddings: ${model} tissue tiles" - description: "Tile embeddings using ${model} over held-out test split tiles with tile_tissue_coverage > 0." diff --git a/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_0_5mpp.yaml b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_0_5mpp.yaml new file mode 100644 index 0000000..7b9f54f --- /dev/null +++ b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_0_5mpp.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - /experiment/preprocessing/embeddings_virchow2_0_5mpp + - _self_ + +# Embeddings for a deterministic sampled subset of test slides whose tiles +# intersect the tissue mask. The sample is capped by slide_sample_max_tiles and +# selected with slide_sample_seed for doctor-review prediction maps. +splits: + - test +tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} +tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet" +tile_filter_column: tile_tissue_coverage +slide_sample_max_tiles: 2000000 +slide_sample_seed: 0 + +metadata: + run_name: "Embeddings: ${model} tissue tiles" + description: "Tile embeddings using ${model} over a sampled held-out test slide subset with tile_tissue_coverage > 0, capped by slide_sample_max_tiles=${slide_sample_max_tiles} and selected with slide_sample_seed=${slide_sample_seed}." diff --git a/configs/experiment/preprocessing/tile_masks_05mpp.yaml b/configs/experiment/preprocessing/tile_masks_0_5mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/tile_masks_05mpp.yaml rename to configs/experiment/preprocessing/tile_masks_0_5mpp.yaml diff --git a/configs/experiment/preprocessing/tiling_05mpp.yaml b/configs/experiment/preprocessing/tiling_0_5mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/tiling_05mpp.yaml rename to configs/experiment/preprocessing/tiling_0_5mpp.yaml diff --git a/configs/experiment/preprocessing/tissue_masks_mpp2.yaml b/configs/experiment/preprocessing/tissue_masks_2mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/tissue_masks_mpp2.yaml rename to configs/experiment/preprocessing/tissue_masks_2mpp.yaml diff --git a/configs/experiment/preprocessing/tissue_stats_05mpp.yaml b/configs/experiment/preprocessing/tissue_stats_0_5mpp.yaml similarity index 100% rename from configs/experiment/preprocessing/tissue_stats_05mpp.yaml rename to configs/experiment/preprocessing/tissue_stats_0_5mpp.yaml diff --git a/configs/ml/data/final_embedding_tiles.yaml b/configs/ml/data/final_embedding_tiles.yaml index 826f7a2..c7ffa53 100644 --- a/configs/ml/data/final_embedding_tiles.yaml +++ b/configs/ml/data/final_embedding_tiles.yaml @@ -21,3 +21,4 @@ data: class_indices: ${class_indices} thresholds: ${thresholds} tissue_prop_min: ${tissue_prop_min} + slide_metadata_uri: ${test_slide_metadata_uri} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index 2760ab7..18b29fd 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -6,7 +6,7 @@ model: decode_head: _target_: torch.nn.Linear - in_features: 2560 + in_features: ${embedding_dim} out_features: ${len:${class_indices}} class_indices: ${class_indices} diff --git a/configs/ml/task/final_linear_classifier.yaml b/configs/ml/task/final_linear_classifier.yaml index 61afef7..5e8f3b3 100644 --- a/configs/ml/task/final_linear_classifier.yaml +++ b/configs/ml/task/final_linear_classifier.yaml @@ -10,6 +10,8 @@ defaults: mode: fit +embedding_model_name: Virchow2 +embedding_dim: 2560 embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} kfold_strategy: stratified kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} @@ -19,6 +21,7 @@ train_embedding_uri: runs:/${embedding_run_id}/train/tiles test_embedding_uri: runs:/${embedding_run_id}/test/tiles train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet +test_slide_metadata_uri: runs:/${embedding_run_id}/test/slides.parquet tissue_prop_min: 0.2 thresholds: @@ -34,8 +37,10 @@ mlflow_artifact_path: linear_classifier_final metadata: run_name: Final Linear Classifier ${dataset.name} - description: "Final linear probe over frozen Virchow2 embeddings trained on all training folds for ${trainer.max_epochs} epochs." + description: "Final linear probe over frozen ${embedding_model_name} embeddings trained on all training folds for ${trainer.max_epochs} epochs." hyperparams: + embedding_model_name: ${embedding_model_name} + embedding_dim: ${embedding_dim} embedding_run_id: ${embedding_run_id} kfold_strategy: ${kfold_strategy} kfold_run_id: ${kfold_run_id} diff --git a/configs/ml/task/kfold_linear_classifier.yaml b/configs/ml/task/kfold_linear_classifier.yaml index aa73341..6b6951b 100644 --- a/configs/ml/task/kfold_linear_classifier.yaml +++ b/configs/ml/task/kfold_linear_classifier.yaml @@ -10,6 +10,8 @@ defaults: mode: fit +embedding_model_name: Virchow2 +embedding_dim: 2560 embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} kfold_strategy: stratified kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} @@ -36,8 +38,10 @@ mlflow_artifact_path: linear_classifier metadata: run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} - description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." + description: "Linear probe over frozen ${embedding_model_name} embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: + embedding_model_name: ${embedding_model_name} + embedding_dim: ${embedding_dim} embedding_run_id: ${embedding_run_id} kfold_strategy: ${kfold_strategy} kfold_run_id: ${kfold_run_id} diff --git a/configs/ml/trainer/early_stopping.yaml b/configs/ml/trainer/early_stopping.yaml index 8615025..0e35f84 100644 --- a/configs/ml/trainer/early_stopping.yaml +++ b/configs/ml/trainer/early_stopping.yaml @@ -19,6 +19,7 @@ trainer: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: train/loss_epoch mode: min + save_last: true save_top_k: 1 filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}" auto_insert_metric_name: false diff --git a/configs/preprocessing/embeddings.yaml b/configs/preprocessing/embeddings.yaml index 91400ac..1b98f22 100644 --- a/configs/preprocessing/embeddings.yaml +++ b/configs/preprocessing/embeddings.yaml @@ -11,6 +11,8 @@ splits: tile_source_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} tile_source_artifact_template: "filter_tiles/{split}_tiles.parquet" tile_filter_column: null +slide_sample_max_tiles: null +slide_sample_seed: 0 metadata: run_name: "Embeddings: ${model}" @@ -23,3 +25,5 @@ metadata: tile_source_run_id: ${tile_source_run_id} tile_source_artifact_template: ${tile_source_artifact_template} tile_filter_column: ${tile_filter_column} + slide_sample_max_tiles: ${slide_sample_max_tiles} + slide_sample_seed: ${slide_sample_seed} diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index e8548db..e6a00b3 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -1,6 +1,7 @@ """Write tile predictions as WSI-aligned BigTIFF masks.""" from collections.abc import Mapping +from hashlib import blake2b from pathlib import Path from re import sub from tempfile import TemporaryDirectory @@ -517,7 +518,10 @@ def _safe_filename(value: str) -> str: def _slide_prediction_filename(path: str | Path) -> str: - return Path(str(path)).with_suffix(".tiff").name + path_str = str(path) + stem = Path(path_str).stem + digest = blake2b(path_str.encode("utf-8"), digest_size=4).hexdigest() + return _safe_filename(f"{stem}-{digest}.tiff") def _spread_lut(n_classes: int) -> np.ndarray: diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index ba00eb0..8d1b690 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -27,6 +27,7 @@ def __init__( embedding_uri: str | Path, meta_df: pd.DataFrame, diag: Callable[[str], None], + slide_metadata_uri: str | Path | None = None, ) -> None: diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") joined_keys, embeddings = _load_embeddings_and_join( @@ -37,6 +38,9 @@ def __init__( self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) + self.slide_names_by_id = ( + _load_slide_names(slide_metadata_uri) if slide_metadata_uri else {} + ) diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") def __len__(self) -> int: @@ -78,6 +82,7 @@ def __init__( tissue_prop_min: float, include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, + slide_metadata_uri: str | Path | None = None, ) -> None: self.class_indices = class_indices diag = _make_diag(type(self).__name__) @@ -89,7 +94,7 @@ def __init__( include_folds, exclude_folds, ) - super().__init__(embedding_uri, meta_df, diag) + super().__init__(embedding_uri, meta_df, diag, slide_metadata_uri) def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: labels = joined_keys.column("label").to_pandas() @@ -233,12 +238,13 @@ def __init__( tissue_column: str = "tile_tissue_coverage", tissue_min: float = 0.0, label_value: int = -1, + slide_metadata_uri: str | Path | None = None, ) -> None: self.label_value = label_value diag = _make_diag(type(self).__name__) diag("filtering metadata") meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min) - super().__init__(embedding_uri, meta_df, diag) + super().__init__(embedding_uri, meta_df, diag, slide_metadata_uri) def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: return np.full(joined_keys.num_rows, self.label_value, dtype=np.int64) @@ -268,6 +274,12 @@ def _resolve_uri(path_or_uri: str | Path) -> str: return _resolve_uri_cached(str(path_or_uri)) +def _load_slide_names(slide_metadata_uri: str | Path) -> dict[str, str]: + local = _resolve_uri(slide_metadata_uri) + df = pd.read_parquet(local, columns=["id", "path"]) + return {str(row.id): Path(str(row.path)).name for row in df.itertuples(index=False)} + + def _make_diag(dataset_name: str) -> Callable[[str], None]: t0 = perf_counter() diff --git a/ml/meta_arch.py b/ml/meta_arch.py index e1baae3..f490607 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -357,19 +357,28 @@ def _log_per_slide_accuracy(self) -> None: self.log("test/slide_acc_median", float(np.median(accs)), on_epoch=True) self.log("test/slide_acc_min", float(np.min(accs)), on_epoch=True) - rows = [ - { - "slide_id": s, - "tile_accuracy": self._test_slide_correct[s] / n, - "n_tiles": n, - } - for s, n in self._test_slide_total.items() - ] + slide_names = self._test_slide_names_by_id() + rows = [] + for s, n in self._test_slide_total.items(): + row: dict[str, Any] = {"slide_id": s} + if slide_names: + row["slide_name"] = slide_names.get(s) + row["tile_accuracy"] = self._test_slide_correct[s] / n + row["n_tiles"] = n + rows.append(row) mlflow.log_table( data=pd.DataFrame(rows), artifact_file="per_slide/test_tile_accuracy.json", ) + def _test_slide_names_by_id(self) -> dict[str, str]: + datamodule = getattr(self.trainer, "datamodule", None) + dataset = getattr(datamodule, "test", None) + slide_names = getattr(dataset, "slide_names_by_id", {}) + if isinstance(slide_names, dict): + return slide_names + return {} + def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str diff --git a/preprocessing/embeddings.py b/preprocessing/embeddings.py index 7ad74dd..fe9c202 100644 --- a/preprocessing/embeddings.py +++ b/preprocessing/embeddings.py @@ -7,7 +7,9 @@ import httpx import hydra import mlflow.artifacts +import numpy as np import pandas as pd +import pyarrow as pa import pyarrow.dataset as pads import ray from omegaconf import DictConfig @@ -51,6 +53,63 @@ async def __call__(self, row: dict[str, Any]) -> dict[str, Any]: return row +def select_slide_budget( + tiles_dataset: pads.Dataset, + row_filter: pads.Expression | None, + slide_order: pd.Series, + max_tiles: int, + seed: int, +) -> tuple[set[str], int]: + """Select a deterministic random slide subset within a tile budget.""" + if max_tiles <= 0: + raise ValueError(f"slide_sample_max_tiles must be positive, got {max_tiles}") + + slide_ids = tiles_dataset.to_table( + columns=["slide_id"], + filter=row_filter, + ).column("slide_id") + counts = ( + pa.table({"slide_id": slide_ids.value_counts()}) + .flatten() + .to_pandas() + .rename( + columns={ + "slide_id.values": "slide_id", + "slide_id.counts": "tile_count", + } + ) + ) + if counts.empty: + raise ValueError("No tiles available after applying the embedding tile filter") + counts["slide_id"] = counts["slide_id"].astype(str) + + ordered_ids = slide_order.astype(str).tolist() + rank = {slide_id: index for index, slide_id in enumerate(ordered_ids)} + counts["_rank"] = counts["slide_id"].map(rank) + counts = counts.sort_values("_rank").drop(columns="_rank") + + rng = np.random.default_rng(seed) + shuffled = counts.iloc[rng.permutation(len(counts))] + + selected: list[str] = [] + selected_tiles = 0 + for row in shuffled.itertuples(index=False): + tile_count = int(row.tile_count) + if selected_tiles + tile_count > max_tiles: + continue + selected.append(str(row.slide_id)) + selected_tiles += tile_count + if selected_tiles >= max_tiles: + break + + if not selected: + smallest = counts.sort_values("tile_count").iloc[0] + selected = [str(smallest["slide_id"])] + selected_tiles = int(smallest["tile_count"]) + + return set(selected), selected_tiles + + @with_cli_args(["+preprocessing=embeddings"]) @hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) @autolog @@ -62,6 +121,8 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: "tile_source_artifact_template", "filter_tiles/{split}_tiles.parquet" ) tile_filter_column = config.get("tile_filter_column") + slide_sample_max_tiles = config.get("slide_sample_max_tiles") + slide_sample_seed = int(config.get("slide_sample_seed", 0)) for name in config.get("splits", ["train", "test"]): split_folder = Path( @@ -86,9 +147,30 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: if tile_filter_column is not None else None ) - num_rows = pads.dataset(str(tiles_path), format="parquet").count_rows( - filter=row_filter - ) + tiles_dataset = pads.dataset(str(tiles_path), format="parquet") + num_rows = tiles_dataset.count_rows(filter=row_filter) + selected_slide_ids: set[str] | None = None + if slide_sample_max_tiles is not None: + selected_slide_ids, num_rows = select_slide_budget( + tiles_dataset=tiles_dataset, + row_filter=row_filter, + slide_order=slides["id"], + max_tiles=int(slide_sample_max_tiles), + seed=slide_sample_seed, + ) + slides = slides[slides["id"].astype(str).isin(selected_slide_ids)].copy() + slide_info = { + slide_id: info + for slide_id, info in slide_info.items() + if str(slide_id) in selected_slide_ids + } + print( + f"[main] split={name} selected {len(selected_slide_ids)} slides " + f"with {num_rows} tiles " + f"(slide_sample_max_tiles={slide_sample_max_tiles}, " + f"slide_sample_seed={slide_sample_seed})", + flush=True, + ) num_blocks = max(1, num_rows // config.block_size) columns = ["slide_id", "x", "y"] @@ -105,6 +187,11 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: lambda row, c: row[c] > 0, fn_kwargs={"c": tile_filter_column} ) ds = ds.drop_columns([tile_filter_column]) + if selected_slide_ids is not None: + ds = ds.filter( + lambda row, ids: str(row["slide_id"]) in ids, + fn_kwargs={"ids": selected_slide_ids}, + ) ds = ds.map( lambda row, si: {**row, **si[row["slide_id"]]}, fn_kwargs={"si": slide_info},