diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 09f8f4a..f061371 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5" + stratified_kfold_run_id: "850c81506684450b9af92296acfd045a" stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml similarity index 57% rename from configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml rename to configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml index 471f5a3..86f5647 100644 --- a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml @@ -1,8 +1,13 @@ # @package _global_ defaults: - - /ml/linear_classifier + - /ml/task: kfold_linear_classifier - _self_ kfold_strategy: stratified_group kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} + +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 0.0 diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml similarity index 55% rename from configs/experiment/ml/linear_classifier_stratified_kfold.yaml rename to configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml index c01fbbf..5d09c02 100644 --- a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml @@ -1,8 +1,13 @@ # @package _global_ 
defaults: - - /ml/linear_classifier + - /ml/task: kfold_linear_classifier - _self_ kfold_strategy: stratified kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} + +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 0.0 diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/linear_classifier_final_adamw.yaml new file mode 100644 index 0000000..9cc037f --- /dev/null +++ b/configs/experiment/ml/linear_classifier_final_adamw.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +defaults: + - /ml/task: final_linear_classifier + - override /ml/trainer: early_stopping + - _self_ + +# AdamW final: trained to convergence with the same early-stopping rule as the +# k-fold sweep (monitor train/loss_epoch, patience 1, min_delta 1e-4), not a +# fixed 6-epoch budget. weight_decay=1e-3 = best AdamW sweep point (flat curve). +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 1.0e-3 + +trainer: + callbacks: + model_checkpoint: + save_last: true + +metadata: + run_name: Final Linear Classifier AdamW ${dataset.name} + description: "Final AdamW linear probe over frozen Virchow2 embeddings, trained on all training folds with early stopping on train/loss_epoch." diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml new file mode 100644 index 0000000..5169d83 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +defaults: + - /ml/task: final_linear_classifier + - _self_ + +# LBFGS final: exact solve of the convex objective on the full training batch. +# weight_decay=1e-2 = best LBFGS sweep point. Full-batch guard requires +# train_shuffle=false, train_drop_last=false, train_batch_size >= len(train); +# num_workers=0 avoids the single-batch IPC deadlock. 
+trainer: + max_epochs: 10 + +data: + train_batch_size: 1000000000 + eval_batch_size: 1024 + train_shuffle: false + train_drop_last: false + num_workers: 0 + +model: + optimizer: lbfgs + learning_rate: 1.0 + weight_decay: 1.0e-2 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false + +metadata: + run_name: Final Linear Classifier LBFGS ${dataset.name} + description: "Final LBFGS linear probe over frozen Virchow2 embeddings, exact full-batch solve of the convex objective on all training folds." diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml index 4d92561..3f11835 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -8,7 +8,7 @@ trainer: max_epochs: 10 data: - batch_size: 1000000000 + train_batch_size: 1000000000 train_shuffle: false train_drop_last: false num_workers: 0 diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml index bd3c10b..f857ccd 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml @@ -8,7 +8,7 @@ trainer: max_epochs: 10 data: - batch_size: 1000000000 + train_batch_size: 1000000000 train_shuffle: false train_drop_last: false diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml new file mode 100644 index 0000000..a8d4ad9 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +defaults: + - /ml/task: final_linear_classifier + - 
_self_ + +# Predict over every test-split tile that intersects the tissue mask. This run +# is unlabeled: it writes predictions/maps for review and does not compute test +# metrics. +mode: predict +final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false + +tissue_embedding_run_id: 95a02c93c164415e94702ad5c83ccca2 +tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} +tissue_stats_artifact_path: tissue_stats +tissue_column: tile_tissue_coverage +tissue_min: 0.0 + +test_embedding_uri: runs:/${tissue_embedding_run_id}/test/tiles +test_metadata_uri: runs:/${tissue_stats_run_id}/${tissue_stats_artifact_path}/test_tiles.parquet + +data: + # num_workers MUST be 0 for this config. UnlabeledEmbeddingTilesDataset and + # EmbeddingTilesDataset load the entire split into memory in __init__; using + # worker processes (num_workers > 0) can deadlock after multiprocessing/fork. + num_workers: 0 + predict: + _target_: ml.data.datasets.UnlabeledEmbeddingTilesDataset + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} + tissue_column: ${tissue_column} + tissue_min: ${tissue_min} + +model: + optimizer: lbfgs + learning_rate: 1.0 + weight_decay: 1.0e-2 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false + +metadata: + run_name: Predict Linear Classifier tissue tiles ${dataset.name} + description: "Predict final linear classifier over all held-out test tiles intersecting tissue masks for external doctor review." 
+ hyperparams: + tissue_embedding_run_id: ${tissue_embedding_run_id} + tissue_stats_run_id: ${tissue_stats_run_id} + tissue_stats_artifact_path: ${tissue_stats_artifact_path} + tissue_column: ${tissue_column} + tissue_min: ${tissue_min} diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml new file mode 100644 index 0000000..ed214e5 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final_adamw + - _self_ + +# Test the AdamW final checkpoint on the held-out test split. Same model +# architecture as the final run (required for state_dict load); optimizer +# fields are inert at test. +# +mode: test +final_train_run_id: a23e478b00b04da79cfbf4d91cada8cd +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false + +# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into +# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing +# (no per-item IO), so workers give zero speedup; num_workers>0 forks the +# parent (pyarrow/mlflow/fsspec thread state + large array) and deadlocks +# before the first test batch. final_embedding_tiles defaults to 4; override here. 
+data: + num_workers: 0 + +trainer: + callbacks: + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml new file mode 100644 index 0000000..c005c37 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final_lbfgs + - override /ml/trainer: early_stopping + - _self_ + +# Test the LBFGS final checkpoint on the held-out test split. The full-batch +# train_batch_size=1e9 is a TRAINING requirement for the convex LBFGS solve +# only; at test there is no optimization, so use a normal batch to avoid +# loading the whole test set as one tensor (OOM). +# +# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into +# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing +# (no per-item IO), so workers give zero speedup; num_workers>0 forks the +# parent (which holds pyarrow/mlflow/fsspec thread state + the large array), +# which deadlocks before the first test batch. +# Same model architecture as the final run (required for state_dict load). 
+mode: test +final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false + +data: + train_batch_size: 1024 + num_workers: 0 + +trainer: + callbacks: + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null diff --git a/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml new file mode 100644 index 0000000..4ea45ff --- /dev/null +++ b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +defaults: + - /experiment/preprocessing/embeddings_virchow2_05mpp + - _self_ + +# Embeddings for every tile in the held-out test tiling split that intersects the +# tissue mask. This is the tile universe used for doctor-review prediction maps. +splits: + - test +tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} +tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet" +tile_filter_column: tile_tissue_coverage + +metadata: + run_name: "Embeddings: ${model} tissue tiles" + description: "Tile embeddings using ${model} over held-out test split tiles with tile_tissue_coverage > 0." diff --git a/configs/ml.yaml b/configs/ml.yaml index bfde61e..c34ff6d 100644 --- a/configs/ml.yaml +++ b/configs/ml.yaml @@ -6,6 +6,7 @@ defaults: seed: ${random_seed:} mode: ??? 
checkpoint: null +checkpoint_weights_only: null trainer: {} diff --git a/configs/ml/data/final_embedding_tiles.yaml b/configs/ml/data/final_embedding_tiles.yaml new file mode 100644 index 0000000..826f7a2 --- /dev/null +++ b/configs/ml/data/final_embedding_tiles.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +data: + train_batch_size: 1024 + num_workers: 4 + train_shuffle: true + train_drop_last: false + + train: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} + + test: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/kfold_embedding_tiles.yaml similarity index 97% rename from configs/ml/data/embedding.yaml rename to configs/ml/data/kfold_embedding_tiles.yaml index 40ff4b7..e1248a8 100644 --- a/configs/ml/data/embedding.yaml +++ b/configs/ml/data/kfold_embedding_tiles.yaml @@ -1,7 +1,7 @@ # @package _global_ data: - batch_size: 1024 + train_batch_size: 1024 num_workers: 4 train_shuffle: true train_drop_last: true diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index 4b4d9e8..2760ab7 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,15 +11,7 @@ model: class_indices: ${class_indices} - optimizer: adamw - learning_rate: 1.0e-4 - weight_decay: 0.0 - lbfgs: - max_iter: 100 - max_eval: null - tolerance_grad: 1.0e-7 - tolerance_change: 1.0e-9 - history_size: 100 - line_search_fn: strong_wolfe - accumulate_batches: 1 - accumulate_on_cpu: false + optimizer: ??? + learning_rate: ??? + weight_decay: ??? 
+ lbfgs: null diff --git a/configs/ml/task/final_linear_classifier.yaml b/configs/ml/task/final_linear_classifier.yaml new file mode 100644 index 0000000..61afef7 --- /dev/null +++ b/configs/ml/task/final_linear_classifier.yaml @@ -0,0 +1,48 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/trainer: final_with_prediction_maps + - /ml/data: final_embedding_tiles + - /ml/model: linear_classifier + - _self_ + +mode: fit + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +mlflow_artifact_path: linear_classifier_final + +metadata: + run_name: Final Linear Classifier ${dataset.name} + description: "Final linear probe over frozen Virchow2 embeddings trained on all training folds for ${trainer.max_epochs} epochs." 
+ hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + learning_rate: ${model.learning_rate} + weight_decay: ${model.weight_decay} + batch_size: ${data.train_batch_size} + max_epochs: ${trainer.max_epochs} diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/task/kfold_linear_classifier.yaml similarity index 93% rename from configs/ml/linear_classifier.yaml rename to configs/ml/task/kfold_linear_classifier.yaml index d339337..aa73341 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/task/kfold_linear_classifier.yaml @@ -3,8 +3,8 @@ defaults: - /data: dataset - /class_mapping: collapse_alterations_to_other - - /ml/trainer: default - - /ml/data: embedding + - /ml/trainer: early_stopping + - /ml/data: kfold_embedding_tiles - /ml/model: linear_classifier - _self_ @@ -49,6 +49,6 @@ metadata: learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} lbfgs: ${model.lbfgs} - batch_size: ${data.batch_size} + batch_size: ${data.train_batch_size} train_shuffle: ${data.train_shuffle} train_drop_last: ${data.train_drop_last} diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/early_stopping.yaml similarity index 100% rename from configs/ml/trainer/default.yaml rename to configs/ml/trainer/early_stopping.yaml diff --git a/configs/ml/trainer/final_with_prediction_maps.yaml b/configs/ml/trainer/final_with_prediction_maps.yaml new file mode 100644 index 0000000..c744d28 --- /dev/null +++ b/configs/ml/trainer/final_with_prediction_maps.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +trainer: + max_epochs: 6 + accelerator: auto + devices: auto + precision: 32 + log_every_n_steps: 50 + deterministic: false + + callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + save_last: true + save_top_k: 0 + filename: 
"final-epoch={epoch}" + auto_insert_metric_name: false + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + prediction_writer: + _target_: ml.callbacks.ParquetPredictionWriter + output_filename: predictions.parquet + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null diff --git a/configs/preprocessing/embeddings.yaml b/configs/preprocessing/embeddings.yaml index db9ceaa..91400ac 100644 --- a/configs/preprocessing/embeddings.yaml +++ b/configs/preprocessing/embeddings.yaml @@ -5,6 +5,12 @@ output_dir: ${project_path}/embeddings concurrency: 512 block_size: 2048 rows_per_file: 5000 +splits: + - train + - test +tile_source_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} +tile_source_artifact_template: "filter_tiles/{split}_tiles.parquet" +tile_filter_column: null metadata: run_name: "Embeddings: ${model}" @@ -13,3 +19,7 @@ metadata: model: ${model} concurrency: ${concurrency} block_size: ${block_size} + splits: ${splits} + tile_source_run_id: ${tile_source_run_id} + tile_source_artifact_template: ${tile_source_artifact_template} + tile_filter_column: ${tile_filter_column} diff --git a/ml/__main__.py b/ml/__main__.py index 318c37c..b474b9e 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -3,6 +3,7 @@ import hydra import mlflow from lightning import seed_everything +from mlflow.artifacts import download_artifacts from omegaconf import DictConfig, OmegaConf from rationai.mlkit import Trainer, autolog from rationai.mlkit.lightning.loggers import MLFlowLogger @@ -28,9 +29,23 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: allowed_modes = ["fit", "test", "validate", "predict"] if config.mode not in allowed_modes: raise ValueError(f"Invalid mode {config.mode!r}. 
Allowed: {allowed_modes}") - getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + run_kwargs = { + "datamodule": data, + "ckpt_path": _resolve_checkpoint(config.checkpoint), + } + if config.checkpoint_weights_only is not None: + run_kwargs["weights_only"] = config.checkpoint_weights_only + getattr(trainer, config.mode)(model, **run_kwargs) mlflow.end_run() +def _resolve_checkpoint(checkpoint: str | None) -> str | None: + if checkpoint is None: + return None + if checkpoint.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=checkpoint) + return checkpoint + + if __name__ == "__main__": main() diff --git a/ml/callbacks/__init__.py b/ml/callbacks/__init__.py index e9c20c4..ada21e8 100644 --- a/ml/callbacks/__init__.py +++ b/ml/callbacks/__init__.py @@ -1,4 +1,5 @@ from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter +from ml.callbacks.tiff_prediction_map_writer import TiffPredictionMapWriter -__all__ = ["ParquetPredictionWriter"] +__all__ = ["ParquetPredictionWriter", "TiffPredictionMapWriter"] diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index a6f676b..68a6dcd 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -39,11 +39,15 @@ def write_on_epoch_end( ) slide_ids: list[str] = [] + xs: list[int] = [] + ys: list[int] = [] targets: list[int] = [] preds: list[int] = [] probs: list[np.ndarray] = [] for b in batches: slide_ids.extend(b["slide_id"]) + xs.extend(b["x"].tolist()) + ys.extend(b["y"].tolist()) targets.extend(b["target"].tolist()) preds.extend(b["pred"].tolist()) probs.append(b["probs"].numpy()) @@ -60,7 +64,9 @@ def write_on_epoch_end( else [f"prob_{i}" for i in range(prob_matrix.shape[1])] ) - df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds}) + df = pd.DataFrame( + {"slide_id": slide_ids, "x": xs, "y": ys, "target": targets, "pred": preds} 
+ ) df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1) out_path = Path(trainer.default_root_dir) / self.output_filename diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py new file mode 100644 index 0000000..e8548db --- /dev/null +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -0,0 +1,538 @@ +"""Write tile predictions as WSI-aligned BigTIFF masks.""" + +from collections.abc import Mapping +from pathlib import Path +from re import sub +from tempfile import TemporaryDirectory +from typing import Any, cast + +import lightning as pl +import mlflow +import numpy as np +import pandas as pd +import torch +from lightning.pytorch.callbacks import Callback +from mlflow.artifacts import download_artifacts + + +class TiffPredictionMapWriter(Callback): + """Collect test/predict batches and log per-slide prediction TIFFs. + + The output masks use the same coordinate space and MPP as the tiling + ``slides.parquet`` artifact. Class maps store the predicted class index as + ``uint8`` and use ``background_value`` for pixels without predictions. + """ + + def __init__( + self, + slides_uri: str, + artifact_path: str = "prediction_maps_tiff", + background_value: int = 255, + draw_region: str = "central_stride", + max_slides: int | None = None, + slide_selection: str = "all", + ) -> None: + super().__init__() + if draw_region != "central_stride": + raise ValueError( + "draw_region must be 'central_stride'; 'tile' is unsupported " + "for class maps (overlapping tiles would average categorical " + f"class indices). 
got {draw_region!r}" + ) + if slide_selection not in {"all", "worst"}: + raise ValueError( + "slide_selection must be either 'all' or 'worst', " + f"got {slide_selection!r}" + ) + if max_slides is not None and max_slides <= 0: + raise ValueError(f"max_slides must be positive or None, got {max_slides}") + self.slides_uri = slides_uri + self.artifact_path = artifact_path + self.background_value = background_value + self.draw_region = draw_region + self.max_slides = max_slides + self.slide_selection = slide_selection + self._batches: list[dict[str, Any]] = [] + self._written = False + + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._batches.clear() + self._written = False + print("[TiffPredictionMapWriter] test loop started", flush=True) + + def on_predict_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._batches.clear() + self._written = False + + def on_test_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: torch.Tensor | Mapping[str, Any] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if isinstance(outputs, Mapping): + self._batches.append(_to_cpu_batch(outputs)) + if trainer.global_rank == 0 and batch_idx % 50 == 0: + print( + f"[TiffPredictionMapWriter] test batch {batch_idx} " + f"({len(self._batches)} buffered)", + flush=True, + ) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: dict[str, Any] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if outputs is not None: + self._batches.append(_to_cpu_batch(outputs)) + + def on_test_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._write_maps(trainer) + + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._write_maps(trainer) + + def on_predict_epoch_end( + self, trainer: pl.Trainer, pl_module: 
pl.LightningModule + ) -> None: + self._write_maps(trainer) + + def on_predict_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._write_maps(trainer) + + def _write_maps(self, trainer: pl.Trainer) -> None: + if self._written: + return + batches = _gather_batches(self._batches) + if trainer.global_rank != 0: + self._batches.clear() + return + if not batches: + print( + "[TiffPredictionMapWriter] no buffered batches at write time " + f"(local buffer={len(self._batches)}, " + f"dist_init={torch.distributed.is_available() and torch.distributed.is_initialized()}); " + "skipping", + flush=True, + ) + return + + predictions = _batches_to_dataframe(batches) + if predictions.empty: + print( + "[TiffPredictionMapWriter] predictions dataframe empty; skipping", + flush=True, + ) + return + + slides = pd.read_parquet(_resolve_uri(self.slides_uri)) + slides_by_id = { + str(row["id"]): cast("dict[str, Any]", row) + for row in slides.to_dict(orient="records") + } + + with TemporaryDirectory(dir=Path(trainer.default_root_dir)) as output_dir: + output_path = Path(output_dir) + Path(output_path, "pred").mkdir(parents=True, exist_ok=True) + Path(output_path, "prob").mkdir(parents=True, exist_ok=True) + + slide_groups = self._select_slide_groups(predictions) + print( + f"[TiffPredictionMapWriter] writing {len(slide_groups)} " + f"prediction map(s)", + flush=True, + ) + for index, (slide_id, slide_predictions) in enumerate( + slide_groups, start=1 + ): + slide = slides_by_id.get(str(slide_id)) + if slide is None: + raise KeyError( + f"slide_id {slide_id!r} not found in slides artifact " + f"{self.slides_uri!r}" + ) + print( + f"[TiffPredictionMapWriter] {index}/{len(slide_groups)} " + f"{Path(str(slide['path'])).name}", + flush=True, + ) + self._write_slide_maps( + slide, + slide_predictions, + output_path, + class_names=getattr(trainer.lightning_module, "class_names", None), + ) + + active = mlflow.active_run() + if active is not None: + print( + 
f"[TiffPredictionMapWriter] logging artifacts to " + f"{self.artifact_path}", + flush=True, + ) + mlflow.log_artifacts(output_dir, artifact_path=self.artifact_path) + + self._written = True + self._batches.clear() + + def _select_slide_groups( + self, predictions: pd.DataFrame + ) -> list[tuple[str, pd.DataFrame]]: + if self.slide_selection == "worst": + predictions = predictions.assign( + _correct=predictions["pred"] == predictions["target"] + ) + slide_ids = ( + predictions.groupby("slide_id", sort=False)["_correct"] + .mean() + .sort_values() + .index + ) + if self.max_slides is not None: + slide_ids = slide_ids[: self.max_slides] + selected = predictions[predictions["slide_id"].isin(slide_ids)] + return [ + (str(slide_id), slide_predictions.drop(columns=["_correct"])) + for slide_id, slide_predictions in selected.groupby( + "slide_id", sort=False + ) + ] + + groups = [ + (str(slide_id), slide_predictions) + for slide_id, slide_predictions in predictions.groupby( + "slide_id", sort=False + ) + ] + return groups[: self.max_slides] if self.max_slides is not None else groups + + def _write_slide_maps( + self, + slide: dict[str, Any], + predictions: pd.DataFrame, + output_path: Path, + class_names: list[str] | None, + ) -> None: + filename = _slide_prediction_filename(slide["path"]) + extent = (int(slide["extent_x"]), int(slide["extent_y"])) + tile_extent = (int(slide["tile_extent_x"]), int(slide["tile_extent_y"])) + stride = (int(slide["stride_x"]), int(slide["stride_y"])) + mpp = (float(slide["mpp_x"]), float(slide["mpp_y"])) + + xs = predictions["x"].to_numpy(dtype=np.int64) + ys = predictions["y"].to_numpy(dtype=np.int64) + probs = np.stack( + [np.asarray(prob, dtype=np.float32) for prob in predictions["probs"]] + ) + + pred_path = Path(output_path, "pred", filename) + print( + f"[TiffPredictionMapWriter] writing pred/{filename}", + flush=True, + ) + _write_class_map( + probs=probs, + xs=xs, + ys=ys, + extent=extent, + tile_extent=tile_extent, + stride=stride, + 
path=pred_path, + mpp=mpp, + background_value=self.background_value, + ) + print( + f"[TiffPredictionMapWriter] wrote pred/{filename}", + flush=True, + ) + + names = ( + class_names + if class_names is not None and len(class_names) == probs.shape[1] + else [f"class_{i}" for i in range(probs.shape[1])] + ) + _write_per_class_probability_maps( + probs=probs, + class_names=names, + xs=xs, + ys=ys, + extent=extent, + tile_extent=tile_extent, + stride=stride, + output_dir=Path(output_path, "prob"), + filename=filename, + mpp=mpp, + ) + + +class _ClassVoteAssembler: + """Confidence-weighted per-class accumulator over the full tile footprint. + + Mirrors ``HeatmapAssembler``'s GCD-compressed grid + ``(x, y)`` -> ROI + mapping, but accumulates the per-class softmax of every tile that covers a + cell instead of summing a categorical index. Overlapping strided tiles are + fused by ``argmax`` over the summed probabilities, so heavy stride overlap + (e.g. tile 224 / stride 112) is used instead of discarded. ``count == 0`` + marks never-covered cells as background. 
+ """ + + def __init__( + self, + extent_x: int, + extent_y: int, + tile_x: int, + tile_y: int, + stride_x: int, + stride_y: int, + n_classes: int, + ) -> None: + from math import gcd + + self.cdx = gcd(stride_x, tile_x) + self.cdy = gcd(stride_y, tile_y) + self._sx = stride_x // self.cdx + self._sy = stride_y // self.cdy + self._tx = tile_x // self.cdx + self._ty = tile_y // self.cdy + self._gx = extent_x // self.cdx + self._gy = extent_y // self.cdy + self._n_classes = n_classes + self._acc = torch.zeros(n_classes, self._gy, self._gx, dtype=torch.float32) + self._count = torch.zeros(self._gy, self._gx, dtype=torch.int32) + + def update(self, probs: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor) -> None: + for prob, x, y in zip(probs, xs, ys, strict=False): + cx = int(x.item()) // self.cdx + cy = int(y.item()) // self.cdy + x0, y0 = cx * self._sx, cy * self._sy + x1 = min(x0 + self._tx, self._gx) + y1 = min(y0 + self._ty, self._gy) + if x1 <= x0 or y1 <= y0: + continue + self._acc[:, y0:y1, x0:x1] += prob[:, None, None] + self._count[y0:y1, x0:x1] += 1 + + def labels(self, background_value: int = 0) -> np.ndarray: + """Encode prediction labels like ``remap_annotation_masks``. + + Class ``i`` (0-based) maps to + ``round(255 * (i + 1) / n_classes)``. Never-covered pixels map to + ``background_value``. + + The reporting tool expects GT and prediction masks in the same + evenly-spread value space, so this must mirror that LUT exactly. + """ + idx = self._acc.argmax(0).numpy() + out = _spread_lut(self._n_classes)[idx].astype(np.uint8) + out[self._count.numpy() == 0] = _uint8_scalar(background_value) + return np.ascontiguousarray(out) + + +def _emit_mask( + grid: np.ndarray, + cdx: int, + cdy: int, + extent: tuple[int, int], + background_value: int, + path: Path, + mpp: tuple[float, float], +) -> None: + """Upscale the GCD-compressed grid back to WSI extent and write a BigTIFF. 
+ + NEAREST resize by the GCD factor (tile footprint placed at its true + position, no recenter), background-padded to the full extent. + """ + import pyvips + from rationai.masks import write_big_tiff + + path.parent.mkdir(parents=True, exist_ok=True) + mask = pyvips.Image.new_from_array(grid).cast(pyvips.BandFormat.UCHAR) + mask = mask.resize(cdx, vscale=cdy, kernel=pyvips.enums.Kernel.NEAREST) + mask = mask.embed( + 0, + 0, + extent[0], + extent[1], + extend=pyvips.enums.Extend.BACKGROUND, + background=[background_value], + ) + write_big_tiff(mask, path, mpp[0], mpp[1]) + + +def _write_class_map( + probs: np.ndarray, + xs: np.ndarray, + ys: np.ndarray, + extent: tuple[int, int], + tile_extent: tuple[int, int], + stride: tuple[int, int], + path: Path, + mpp: tuple[float, float], + background_value: int = 0, +) -> None: + assembler = _ClassVoteAssembler( + extent[0], + extent[1], + tile_extent[0], + tile_extent[1], + stride[0], + stride[1], + n_classes=probs.shape[1], + ) + assembler.update( + torch.from_numpy(probs), + torch.from_numpy(xs), + torch.from_numpy(ys), + ) + _emit_mask( + assembler.labels(background_value), + assembler.cdx, + assembler.cdy, + extent, + int(_uint8_scalar(background_value)), + path, + mpp, + ) + + +def _write_per_class_probability_maps( + probs: np.ndarray, + class_names: list[str], + xs: np.ndarray, + ys: np.ndarray, + extent: tuple[int, int], + tile_extent: tuple[int, int], + stride: tuple[int, int], + output_dir: Path, + filename: str, + mpp: tuple[float, float], +) -> None: + """Write one grayscale probability map per class. + + Each class channel is assembled independently with HeatmapAssembler, so + overlapping tiles are averaged exactly like scalar heatmaps. Values are + encoded as uint8 probabilities in [0, 255], with never-covered pixels set + to 0. 
+ """ + from rationai.masks.heatmap_assembler import HeatmapAssembler + + xs_t = torch.from_numpy(xs) + ys_t = torch.from_numpy(ys) + for class_idx, class_name in enumerate(class_names): + class_dir = _safe_filename(class_name) + print( + f"[TiffPredictionMapWriter] writing prob/{class_dir}/{filename}", + flush=True, + ) + assembler = HeatmapAssembler( + extent[0], + extent[1], + tile_extent[0], + tile_extent[1], + stride[0], + stride[1], + dtype=torch.float32, + ) + assembler.update( + torch.from_numpy(probs[:, class_idx].astype(np.float32, copy=False)), + xs_t, + ys_t, + ) + grid = np.clip(assembler.compute().numpy() * 255.0, 0, 255).astype(np.uint8) + grid[assembler._count.numpy() == 0] = 0 + _emit_mask( + np.ascontiguousarray(grid), + assembler.common_divisor_x, + assembler.common_divisor_y, + extent, + 0, + output_dir / class_dir / filename, + mpp, + ) + print( + f"[TiffPredictionMapWriter] wrote prob/{class_dir}/{filename}", + flush=True, + ) + + +def _gather_batches(batches: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return batches + + gathered: list[list[dict[str, Any]]] = [ + [] for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather_object(gathered, batches) + return [batch for rank_batches in gathered for batch in rank_batches] + + +def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: + return { + key: value.detach().cpu() if isinstance(value, torch.Tensor) else value + for key, value in batch.items() + } + + +def _batches_to_dataframe(batches: list[dict[str, Any]]) -> pd.DataFrame: + rows: list[pd.DataFrame] = [] + for batch in batches: + frame = pd.DataFrame( + { + "slide_id": list(batch["slide_id"]), + "x": batch["x"].numpy(), + "y": batch["y"].numpy(), + "target": batch["target"].numpy(), + "pred": batch["pred"].numpy(), + } + ) + # per-row softmax vector kept as an object column so it survives the + # per-slide groupby; 
stacked back to (n, C) in the assembler. + frame["probs"] = list(batch["probs"].numpy()) + rows.append(frame) + return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() + + +def _resolve_uri(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri + + +def _safe_filename(value: str) -> str: + return sub(r"[^A-Za-z0-9 _.-]+", "_", value) + + +def _slide_prediction_filename(path: str | Path) -> str: + return Path(str(path)).with_suffix(".tiff").name + + +def _spread_lut(n_classes: int) -> np.ndarray: + """0-based class index -> evenly-spread uint8 value. + + Mirrors ``preprocessing/remap_annotation_masks.py`` exactly so prediction + masks share the GT value space the reporting tool expects: + class ``i`` -> ``round(255 * (i + 1) / n_classes)``. Index 0 of the + returned array is unused by callers (background is set separately). + """ + return np.array( + [round(255 * (i + 1) / n_classes) for i in range(n_classes)], + dtype=np.uint8, + ) + + +def _uint8_scalar(value: int) -> int: + return np.asarray(value).clip(0, np.iinfo(np.uint8).max).astype(np.uint8).item() diff --git a/ml/data/data_module.py b/ml/data/data_module.py index bfac118..b96df1e 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -17,14 +17,16 @@ class DataModule(LightningDataModule): def __init__( self, - batch_size: int, + train_batch_size: int, + eval_batch_size: int | None = None, num_workers: int = 0, train_shuffle: bool = True, train_drop_last: bool = True, **datasets: DictConfig, ) -> None: super().__init__() - self.batch_size = batch_size + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size or train_batch_size self.num_workers = num_workers self.train_shuffle = train_shuffle self.train_drop_last = train_drop_last @@ -34,9 +36,17 @@ def setup(self, stage: str) -> None: match stage: case "fit": self.train = instantiate(self.datasets["train"]) - self.val = 
instantiate(self.datasets["val"]) + self.val = ( + instantiate(self.datasets["val"]) + if self.datasets.get("val") is not None + else None + ) case "validate": - self.val = instantiate(self.datasets["val"]) + self.val = ( + instantiate(self.datasets["val"]) + if self.datasets.get("val") is not None + else None + ) case "test": self.test = instantiate(self.datasets["test"]) case "predict": @@ -48,7 +58,7 @@ def setup(self, stage: str) -> None: def train_dataloader(self) -> Iterable[Input]: return DataLoader( self.train, - batch_size=self.batch_size, + batch_size=self.train_batch_size, shuffle=self.train_shuffle, drop_last=self.train_drop_last, num_workers=self.num_workers, @@ -56,19 +66,21 @@ def train_dataloader(self) -> Iterable[Input]: ) def val_dataloader(self) -> Iterable[Input]: + if self.val is None: + return [] return DataLoader( self.val, - batch_size=self.batch_size, + batch_size=self.eval_batch_size, num_workers=self.num_workers, persistent_workers=self.num_workers > 0, ) def test_dataloader(self) -> Iterable[Input]: return DataLoader( - self.test, batch_size=self.batch_size, num_workers=self.num_workers + self.test, batch_size=self.eval_batch_size, num_workers=self.num_workers ) def predict_dataloader(self) -> Iterable[Input]: return DataLoader( - self.predict, batch_size=self.batch_size, num_workers=self.num_workers + self.predict, batch_size=self.eval_batch_size, num_workers=self.num_workers ) diff --git a/ml/data/datasets/__init__.py b/ml/data/datasets/__init__.py index cd2f91a..f3a5b47 100644 --- a/ml/data/datasets/__init__.py +++ b/ml/data/datasets/__init__.py @@ -1,4 +1,7 @@ -from ml.data.datasets.embedding_tiles import EmbeddingTilesDataset +from ml.data.datasets.embedding_tiles import ( + EmbeddingTilesDataset, + UnlabeledEmbeddingTilesDataset, +) -__all__ = ["EmbeddingTilesDataset"] +__all__ = ["EmbeddingTilesDataset", "UnlabeledEmbeddingTilesDataset"] diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 
5160ba1..ba00eb0 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -2,11 +2,13 @@ Joins precomputed tile embeddings with tile metadata (k-fold parquet for train, filter_tiles parquet for test) and applies tissue + per-class thresholds at -load time to produce ``(embedding, class_index, slide_id)`` triples. +load time to produce ``(embedding, class_index, slide_id, x, y)`` samples. """ +from collections.abc import Callable from functools import cache from pathlib import Path +from time import perf_counter import numpy as np import pandas as pd @@ -19,7 +21,41 @@ from ml.typing import Sample -class EmbeddingTilesDataset(Dataset[Sample]): +class _BaseEmbeddingTilesDataset(Dataset[Sample]): + def __init__( + self, + embedding_uri: str | Path, + meta_df: pd.DataFrame, + diag: Callable[[str], None], + ) -> None: + diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") + joined_keys, embeddings = _load_embeddings_and_join( + embedding_uri, meta_df, diag + ) + self.embeddings = embeddings + self.labels = self._labels_from_joined_keys(joined_keys) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) + self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) + diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + int(self.xs[idx]), + int(self.ys[idx]), + ) + + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: + raise NotImplementedError + + +class EmbeddingTilesDataset(_BaseEmbeddingTilesDataset): """Tile-level embedding dataset with on-the-fly filtering and labeling. 
Inner-joins ``embedding`` parquet with ``metadata`` parquet on @@ -43,6 +79,9 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: + self.class_indices = class_indices + diag = _make_diag(type(self).__name__) + diag("filtering metadata") meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -50,81 +89,16 @@ def __init__( include_folds, exclude_folds, ) + super().__init__(embedding_uri, meta_df, diag) - emb_dir = self._resolve_uri(embedding_uri) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) - - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_col.type): - target_type = pa.large_list(emb_col.type.value_type) - emb_col = pa.chunked_array( - [c.cast(target_type) for c in emb_col.chunks], type=target_type - ) - - emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) - emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) - del emb_table - - meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) - del meta_df - joined_keys = meta_table.join( - emb_keys, keys=["slide_id", "x", "y"], join_type="inner" - ) - del emb_keys, meta_table - if joined_keys.num_rows == 0: - raise RuntimeError("inner join with embeddings produced empty dataset") - - _idx_col = joined_keys.column("_emb_idx") - if isinstance(_idx_col, pa.ChunkedArray): - _idx_col = _idx_col.combine_chunks() - indices_np = _idx_col.to_numpy() - - first_chunk = emb_col.chunks[0] - embedding_dim = len(first_chunk.values) // len(first_chunk) - - # sort indices for sequential per-chunk access; restore order afterwards - sort_order = np.argsort(indices_np) - sorted_indices = indices_np[sort_order] - - chunk_offsets = np.concatenate( - [[0], np.cumsum([len(c) for c in emb_col.chunks])] - ) - embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) - for ci, chunk in enumerate(emb_col.chunks): - lo, hi = chunk_offsets[ci], 
chunk_offsets[ci + 1] - mask = (sorted_indices >= lo) & (sorted_indices < hi) - if not mask.any(): - continue - local_idx = sorted_indices[mask] - lo - chunk_np = ( - chunk.values.to_numpy(zero_copy_only=False) - .reshape(len(chunk), embedding_dim) - .astype(np.float32) - ) - embeddings[sort_order[mask]] = chunk_np[local_idx] - del emb_col - - self.embeddings = embeddings + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: labels = joined_keys.column("label").to_pandas() - unknown = set(labels.unique()) - set(class_indices.keys()) + unknown = set(labels.unique()) - set(self.class_indices.keys()) if unknown: raise ValueError( f"labels in data not present in class_indices: {sorted(unknown)}" ) - self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) - self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() - - def __len__(self) -> int: - return len(self.labels) - - def __getitem__(self, idx: int) -> Sample: - return ( - torch.from_numpy(self.embeddings[idx]), - int(self.labels[idx]), - str(self.slide_ids[idx]), - ) + return labels.map(self.class_indices).to_numpy(dtype=np.int64) @staticmethod def _filter_metadata( @@ -134,7 +108,7 @@ def _filter_metadata( include_folds: list[int] | None, exclude_folds: list[int] | None, ) -> pd.DataFrame: - local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + local = _resolve_uri(metadata_uri) df = pd.read_parquet(local) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] @@ -185,13 +159,126 @@ def _filter_metadata( return df[["slide_id", "x", "y", "label"]] - @staticmethod - def _resolve_uri(path_or_uri: str | Path) -> str: - return EmbeddingTilesDataset._resolve_uri_cached(str(path_or_uri)) + +def _load_embeddings_and_join( + embedding_uri: str | Path, + meta_df: pd.DataFrame, + diag: Callable[[str], None], +) -> tuple[pa.Table, np.ndarray]: + emb_dir = _resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + 
columns=["slide_id", "x", "y", "embedding"] + ) + diag(f"embedding table loaded: {emb_table.num_rows} rows") + + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table + + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") + + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() + + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // len(first_chunk) + + # Sort indices for sequential per-chunk access; restore order afterwards. 
+ sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate([[0], np.cumsum([len(c) for c in emb_col.chunks])]) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & (sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col + + return joined_keys, embeddings + + +class UnlabeledEmbeddingTilesDataset(_BaseEmbeddingTilesDataset): + """Tile-embedding dataset for prediction over tiles without class labels.""" + + def __init__( + self, + embedding_uri: str | Path, + metadata_uri: str | Path, + tissue_column: str = "tile_tissue_coverage", + tissue_min: float = 0.0, + label_value: int = -1, + ) -> None: + self.label_value = label_value + diag = _make_diag(type(self).__name__) + diag("filtering metadata") + meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min) + super().__init__(embedding_uri, meta_df, diag) + + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: + return np.full(joined_keys.num_rows, self.label_value, dtype=np.int64) @staticmethod - @cache - def _resolve_uri_cached(uri: str) -> str: - if uri.startswith(("mlflow-artifacts:/", "runs:/")): - return download_artifacts(artifact_uri=uri) - return uri + def _filter_metadata( + metadata_uri: str | Path, + tissue_column: str, + tissue_min: float, + ) -> pd.DataFrame: + local = _resolve_uri(metadata_uri) + columns = ["slide_id", "x", "y", tissue_column] + df = pd.read_parquet(local, columns=columns) + if tissue_column not in df.columns: + raise ValueError( + f"metadata parquet has no {tissue_column!r} column; cannot filter" + ) + df = 
df.loc[df[tissue_column] > tissue_min, ["slide_id", "x", "y"]] + if df.empty: + raise RuntimeError( + f"all tiles dropped by {tissue_column} > {tissue_min} filter" + ) + return df + + +def _resolve_uri(path_or_uri: str | Path) -> str: + return _resolve_uri_cached(str(path_or_uri)) + + +def _make_diag(dataset_name: str) -> Callable[[str], None]: + t0 = perf_counter() + + def _diag(msg: str) -> None: + print(f"[{dataset_name} +{perf_counter() - t0:6.1f}s] {msg}", flush=True) + + return _diag + + +@cache +def _resolve_uri_cached(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri diff --git a/ml/meta_arch.py b/ml/meta_arch.py index ab6882e..e1baae3 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -56,6 +56,9 @@ def __init__( n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) ] num_classes = len(self.class_names) + # Placeholder weight so `criterion.weight` always exists in state_dict. + # setup(stage="fit") overrides with class-balanced weights; checkpoints + # then load with strict=True regardless of stage. self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) macro_metrics = MetricCollection( @@ -101,6 +104,8 @@ def setup(self, stage: str) -> None: ) for cls, w in zip(self.class_names, weights.tolist(), strict=True): mlflow.log_metric(f"class_weight/{cls}", w) + # Non-fit stages keep the placeholder ones-weight criterion from + # __init__ so `criterion.weight` stays in state_dict for strict load. 
def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) @@ -110,7 +115,7 @@ def training_step(self, batch: Input, batch_idx: int) -> Tensor: if self.hparams["optimizer"] == "lbfgs": return self._lbfgs_training_step(batch, batch_idx) - inputs, targets, _ = batch + inputs, targets, *_ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) @@ -129,7 +134,7 @@ def on_before_optimizer_step(self, optimizer: Optimizer) -> None: ) def validation_step(self, batch: Input, batch_idx: int) -> None: - inputs, targets, _ = batch + inputs, targets, *_ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) self.log("validation/loss", loss, on_epoch=True, prog_bar=True) @@ -142,8 +147,8 @@ def on_validation_epoch_end(self) -> None: self._log_per_class(self.val_per_class, "validation") self._log_confmat(self.val_confmat, "validation") - def test_step(self, batch: Input, batch_idx: int) -> None: - inputs, targets, slide_ids = batch + def test_step(self, batch: Input, batch_idx: int) -> dict[str, Any]: + inputs, targets, slide_ids, xs, ys = batch outputs = self(inputs) self.test_metrics.update(outputs, targets) self.test_per_class.update(outputs, targets) @@ -155,6 +160,14 @@ def test_step(self, batch: Input, batch_idx: int) -> None: for slide_id, ok in zip(slide_ids, correct, strict=True): self._test_slide_correct[slide_id] += int(ok) self._test_slide_total[slide_id] += 1 + return { + "slide_id": list(slide_ids), + "x": xs.cpu(), + "y": ys.cpu(), + "target": targets.cpu(), + "pred": preds.cpu(), + "probs": outputs.softmax(dim=1).cpu(), + } def on_test_epoch_end(self) -> None: self._log_per_class(self.test_per_class, "test") @@ -166,12 +179,14 @@ def on_test_epoch_end(self) -> None: def predict_step( self, batch: Input, batch_idx: int, dataloader_idx: int = 0 ) -> dict[str, Any]: - inputs, targets, slide_ids = batch + inputs, targets, slide_ids, xs, ys = batch outputs 
= self(inputs) probs = outputs.softmax(dim=1) preds = outputs.argmax(dim=1) return { "slide_id": list(slide_ids), + "x": xs.cpu(), + "y": ys.cpu(), "target": targets.cpu(), "pred": preds.cpu(), "probs": probs.cpu(), @@ -198,7 +213,7 @@ def configure_optimizers(self) -> Optimizer: ) def _lbfgs_training_step(self, batch: Input, batch_idx: int) -> Tensor: - inputs, targets, _ = batch + inputs, targets, *_ = batch self._lbfgs_batches.append(self._prepare_lbfgs_batch(inputs, targets)) lbfgs = self.hparams.get("lbfgs") or {} accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) @@ -260,9 +275,9 @@ def _prepare_lbfgs_batch( def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: lbfgs = self.hparams.get("lbfgs") or {} - batch_size = int(datamodule.batch_size) + train_batch_size = int(datamodule.train_batch_size) accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) - effective_batch_size = batch_size * accumulation_steps + effective_batch_size = train_batch_size * accumulation_steps if datamodule.train_shuffle: raise ValueError("LBFGS requires data.train_shuffle=false.") @@ -271,9 +286,9 @@ def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: if effective_batch_size < train_size: raise ValueError( "LBFGS requires a deterministic full-batch objective. Set " - "data.batch_size >= len(train) or set " + "data.train_batch_size >= len(train) or set " "model.lbfgs.accumulate_batches >= ceil(len(train) / " - "data.batch_size). Current effective batch size is " + "data.train_batch_size). Current effective batch size is " f"{effective_batch_size} for {train_size} training samples." 
) diff --git a/ml/typing.py b/ml/typing.py index 7060f5e..5ad219c 100644 --- a/ml/typing.py +++ b/ml/typing.py @@ -1,6 +1,6 @@ import torch -type Sample = tuple[torch.Tensor, int, str] -type Input = tuple[torch.Tensor, torch.Tensor, list[str]] +type Sample = tuple[torch.Tensor, int, str, int, int] +type Input = tuple[torch.Tensor, torch.Tensor, list[str], torch.Tensor, torch.Tensor] type Outputs = torch.Tensor diff --git a/preprocessing/embeddings.py b/preprocessing/embeddings.py index de98143..7ad74dd 100644 --- a/preprocessing/embeddings.py +++ b/preprocessing/embeddings.py @@ -55,7 +55,15 @@ async def __call__(self, row: dict[str, Any]) -> dict[str, Any]: @hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: - for name in ["train", "test"]: + tile_source_run_id = config.get( + "tile_source_run_id", config.dataset.mlflow_artifacts.filter_tiles_run_id + ) + tile_source_artifact_template = config.get( + "tile_source_artifact_template", "filter_tiles/{split}_tiles.parquet" + ) + tile_filter_column = config.get("tile_filter_column") + + for name in config.get("splits", ["train", "test"]): split_folder = Path( mlflow.artifacts.download_artifacts( run_id=config.dataset.mlflow_artifacts.tiling_run_id, @@ -69,19 +77,35 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: tiles_path = Path( mlflow.artifacts.download_artifacts( - run_id=config.dataset.mlflow_artifacts.filter_tiles_run_id, - artifact_path=f"filter_tiles/{name}_tiles.parquet", + run_id=tile_source_run_id, + artifact_path=tile_source_artifact_template.format(split=name), ) ) - num_rows = pads.dataset(str(tiles_path), format="parquet").count_rows() + row_filter = ( + pads.field(tile_filter_column) > 0 + if tile_filter_column is not None + else None + ) + num_rows = pads.dataset(str(tiles_path), format="parquet").count_rows( + filter=row_filter + ) num_blocks = max(1, num_rows // config.block_size) + 
columns = ["slide_id", "x", "y"] + if tile_filter_column is not None: + columns.append(tile_filter_column) ds = ray.data.read_parquet( str(tiles_path), - columns=["slide_id", "x", "y"], + columns=columns, ray_remote_args={"memory": 8 * 1024**3}, override_num_blocks=num_blocks, - ).map( + ) + if tile_filter_column is not None: + ds = ds.filter( + lambda row, c: row[c] > 0, fn_kwargs={"c": tile_filter_column} + ) + ds = ds.drop_columns([tile_filter_column]) + ds = ds.map( lambda row, si: {**row, **si[row["slide_id"]]}, fn_kwargs={"si": slide_info}, ) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py new file mode 100644 index 0000000..3eccdea --- /dev/null +++ b/scripts/submit_test_linear.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-test-linear-final", + username=..., + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=...", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/scripts/submit_train_linear_final.py b/scripts/submit_train_linear_final.py new file mode 100644 index 0000000..b4c3940 --- /dev/null +++ b/scripts/submit_train_linear_final.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear-final", + username=..., + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=...", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 3f7ecb1..dcfc7d9 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -3,16 +3,16 @@ submit_job( 
job_name="tissue-classification-train-linear", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], )