From 24668c3a8204ce086bd09918f7f3f8f9846dcdaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 1 May 2026 16:23:34 +0200 Subject: [PATCH 001/107] feat: create ml pipeline for linear probe --- configs/ml/linear_probe.yaml | 40 ++++++++++++++ ml/__init__.py | 0 ml/data/__init__.py | 0 ml/data/embeddings_datamodule.py | 92 ++++++++++++++++++++++++++++++++ ml/models/__init__.py | 0 ml/models/linear_probe.py | 57 ++++++++++++++++++++ ml/train.py | 29 ++++++++++ 7 files changed, 218 insertions(+) create mode 100644 configs/ml/linear_probe.yaml create mode 100644 ml/__init__.py create mode 100644 ml/data/__init__.py create mode 100644 ml/data/embeddings_datamodule.py create mode 100644 ml/models/__init__.py create mode 100644 ml/models/linear_probe.py create mode 100644 ml/train.py diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml new file mode 100644 index 00000000..318af731 --- /dev/null +++ b/configs/ml/linear_probe.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + - /class_mapping: standard + +mode: fit + +embed_dim: ??? +embeddings_run_id: ??? 
+ +trainer: + _target_: rationai.mlkit.lightning.Trainer + max_epochs: 30 + accelerator: auto + devices: 1 + log_every_n_steps: 50 + +data: + _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule + train_dir: ${project_path}/embeddings/${embeddings_run_id}/train/tiles + test_dir: ${project_path}/embeddings/${embeddings_run_id}/test/tiles + class_names: ${class_mapping.class_names} + batch_size: 1024 + num_workers: 4 + val_fraction: 0.1 + +model: + _target_: ml.models.linear_probe.LinearProbe + embed_dim: ${embed_dim} + num_classes: ${len:${class_mapping.class_names}} + lr: 1e-3 + weight_decay: 0.0 + +metadata: + run_name: "Linear probe (embed_dim=${embed_dim})" + description: Linear probe on cached embeddings + hyperparams: + embed_dim: ${embed_dim} + lr: ${model.lr} + batch_size: ${data.batch_size} diff --git a/ml/__init__.py b/ml/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/data/__init__.py b/ml/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py new file mode 100644 index 00000000..13b0e0d2 --- /dev/null +++ b/ml/data/embeddings_datamodule.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import lightning as pl +import numpy as np +import pyarrow.dataset as pads +import torch +from torch.utils.data import DataLoader, Dataset + + +class EmbeddingsDataset(Dataset): + def __init__( + self, + parquet_dir: str | Path, + class_names: list[str], + coverage_prefix: str = "roi_coverage", + ) -> None: + ds = pads.dataset(str(parquet_dir), format="parquet") + cols = ["embedding"] + [f"{coverage_prefix}_{c}" for c in class_names] + table = ds.to_table(columns=cols) + + self.embeddings = np.stack( + table.column("embedding").to_numpy(zero_copy_only=False) + ) + coverages = np.stack( + [table.column(f"{coverage_prefix}_{c}").to_numpy() for c in class_names], + axis=1, + ) + # hard labels via argmax over coverages; swap for soft targets if needed + 
self.labels = coverages.argmax(axis=1) + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.from_numpy(self.embeddings[idx]).float(), + torch.tensor(self.labels[idx], dtype=torch.long), + ) + + +class EmbeddingsDataModule(pl.LightningDataModule): + def __init__( + self, + train_dir: str, + test_dir: str, + class_names: list[str], + batch_size: int = 1024, + num_workers: int = 4, + val_fraction: float = 0.1, + coverage_prefix: str = "roi_coverage", + ) -> None: + super().__init__() + self.save_hyperparameters() + + def setup(self, stage: str) -> None: + full_train = EmbeddingsDataset( + self.hparams.train_dir, + self.hparams.class_names, + self.hparams.coverage_prefix, + ) + n_val = int(len(full_train) * self.hparams.val_fraction) + n_train = len(full_train) - n_val + self.train_set, self.val_set = torch.utils.data.random_split( + full_train, [n_train, n_val], generator=torch.Generator().manual_seed(0) + ) + self.test_set = EmbeddingsDataset( + self.hparams.test_dir, + self.hparams.class_names, + self.hparams.coverage_prefix, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_set, + batch_size=self.hparams.batch_size, + shuffle=True, + num_workers=self.hparams.num_workers, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_set, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_set, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + ) diff --git a/ml/models/__init__.py b/ml/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/models/linear_probe.py b/ml/models/linear_probe.py new file mode 100644 index 00000000..667df1c6 --- /dev/null +++ b/ml/models/linear_probe.py @@ -0,0 +1,57 @@ +import lightning as pl +import torch +import torch.nn.functional 
as F +from torch import nn, optim +from torchmetrics import Accuracy, F1Score, MetricCollection + + +class LinearProbe(pl.LightningModule): + def __init__( + self, + embed_dim: int, + num_classes: int, + lr: float = 1e-3, + weight_decay: float = 0.0, + ) -> None: + super().__init__() + self.save_hyperparameters() + self.head = nn.Linear(embed_dim, num_classes) + + metrics = MetricCollection( + { + "acc": Accuracy(task="multiclass", num_classes=num_classes), + "f1_macro": F1Score( + task="multiclass", num_classes=num_classes, average="macro" + ), + } + ) + self.train_metrics = metrics.clone(prefix="train/") + self.val_metrics = metrics.clone(prefix="val/") + self.test_metrics = metrics.clone(prefix="test/") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.head(x) + + def _step(self, batch, metrics, log_prefix: str) -> torch.Tensor: + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log(f"{log_prefix}/loss", loss, prog_bar=True) + self.log_dict(metrics(logits, y), prog_bar=True) + return loss + + def training_step(self, batch, _): + return self._step(batch, self.train_metrics, "train") + + def validation_step(self, batch, _): + return self._step(batch, self.val_metrics, "val") + + def test_step(self, batch, _): + return self._step(batch, self.test_metrics, "test") + + def configure_optimizers(self): + return optim.AdamW( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.weight_decay, + ) diff --git a/ml/train.py b/ml/train.py new file mode 100644 index 00000000..0cb3213b --- /dev/null +++ b/ml/train.py @@ -0,0 +1,29 @@ +import hydra +import lightning as pl +from hydra.utils import instantiate +from omegaconf import DictConfig +from rationai.mlkit import autolog, with_cli_args +from rationai.mlkit.lightning.loggers import MLFlowLogger + + +@with_cli_args(["+ml=linear_probe"]) +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: 
MLFlowLogger) -> None: + pl.seed_everything(config.seed) + + datamodule: pl.LightningDataModule = instantiate(config.data) + model: pl.LightningModule = instantiate(config.model) + trainer: pl.Trainer = instantiate(config.trainer, logger=logger) + + if config.mode == "fit": + trainer.fit(model, datamodule=datamodule, ckpt_path=config.checkpoint) + trainer.test(model, datamodule=datamodule) + elif config.mode == "test": + trainer.test(model, datamodule=datamodule, ckpt_path=config.checkpoint) + else: + raise ValueError(f"Unknown mode: {config.mode}") + + +if __name__ == "__main__": + main() From f340038f1e114b8013b779138090a9a2a79f3440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 4 May 2026 20:10:46 +0200 Subject: [PATCH 002/107] refactor(ml): switch DataModule to HF datasets with fold-based split Use the `label` and `fold` columns produced by the upstream k-fold split instead of deriving labels from coverage columns and randomly splitting val. Memory-mapped via HuggingFace datasets so the full embedding parquet no longer has to fit in numpy. 
Co-Authored-By: Claude Opus 4.7 --- configs/ml/linear_probe.yaml | 6 +- ml/data/embeddings_datamodule.py | 109 +++++++++++++------------------ 2 files changed, 51 insertions(+), 64 deletions(-) diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml index 318af731..c8697fc5 100644 --- a/configs/ml/linear_probe.yaml +++ b/configs/ml/linear_probe.yaml @@ -15,14 +15,17 @@ trainer: devices: 1 log_every_n_steps: 50 +val_fold: 0 + data: _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule train_dir: ${project_path}/embeddings/${embeddings_run_id}/train/tiles test_dir: ${project_path}/embeddings/${embeddings_run_id}/test/tiles class_names: ${class_mapping.class_names} + val_fold: ${val_fold} batch_size: 1024 num_workers: 4 - val_fraction: 0.1 + tissue_prop_min: 0.0 model: _target_: ml.models.linear_probe.LinearProbe @@ -38,3 +41,4 @@ metadata: embed_dim: ${embed_dim} lr: ${model.lr} batch_size: ${data.batch_size} + val_fold: ${val_fold} diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 13b0e0d2..8175cec7 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -1,92 +1,75 @@ from pathlib import Path import lightning as pl -import numpy as np -import pyarrow.dataset as pads -import torch -from torch.utils.data import DataLoader, Dataset +from datasets import load_dataset +from torch.utils.data import DataLoader -class EmbeddingsDataset(Dataset): - def __init__( - self, - parquet_dir: str | Path, - class_names: list[str], - coverage_prefix: str = "roi_coverage", - ) -> None: - ds = pads.dataset(str(parquet_dir), format="parquet") - cols = ["embedding"] + [f"{coverage_prefix}_{c}" for c in class_names] - table = ds.to_table(columns=cols) - - self.embeddings = np.stack( - table.column("embedding").to_numpy(zero_copy_only=False) - ) - coverages = np.stack( - [table.column(f"{coverage_prefix}_{c}").to_numpy() for c in class_names], - axis=1, - ) - # hard labels via argmax over 
coverages; swap for soft targets if needed - self.labels = coverages.argmax(axis=1) - - def __len__(self) -> int: - return len(self.labels) - - def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - return ( - torch.from_numpy(self.embeddings[idx]).float(), - torch.tensor(self.labels[idx], dtype=torch.long), - ) +class EmbeddingsDataModule(pl.LightningDataModule): + """Linear-probe data module backed by HuggingFace datasets. + Expects parquet files with columns: `embedding` (list[float]), `label` (str), + `fold` (int, train side only), and `tissue_prop` (float). + """ -class EmbeddingsDataModule(pl.LightningDataModule): def __init__( self, train_dir: str, test_dir: str, class_names: list[str], + val_fold: int = 0, batch_size: int = 1024, num_workers: int = 4, - val_fraction: float = 0.1, - coverage_prefix: str = "roi_coverage", + tissue_prop_min: float = 0.0, ) -> None: super().__init__() self.save_hyperparameters() + self.class_to_idx = {c: i for i, c in enumerate(class_names)} + + def _load(self, parquet_dir: str) -> "Dataset": # type: ignore[name-defined] + files = sorted(str(p) for p in Path(parquet_dir).glob("*.parquet")) + ds = load_dataset("parquet", data_files=files, split="train") + if self.hparams.tissue_prop_min > 0: + ds = ds.filter( + lambda r: r["tissue_prop"] >= self.hparams.tissue_prop_min + ) + ds = ds.map(lambda r: {"y": self.class_to_idx[r["label"]]}) + return ds def setup(self, stage: str) -> None: - full_train = EmbeddingsDataset( - self.hparams.train_dir, - self.hparams.class_names, - self.hparams.coverage_prefix, - ) - n_val = int(len(full_train) * self.hparams.val_fraction) - n_train = len(full_train) - n_val - self.train_set, self.val_set = torch.utils.data.random_split( - full_train, [n_train, n_val], generator=torch.Generator().manual_seed(0) - ) - self.test_set = EmbeddingsDataset( - self.hparams.test_dir, - self.hparams.class_names, - self.hparams.coverage_prefix, + train_full = self._load(self.hparams.train_dir) + 
self.train_set = train_full.filter( + lambda r: r["fold"] != self.hparams.val_fold + ).with_format("torch", columns=["embedding", "y"]) + self.val_set = train_full.filter( + lambda r: r["fold"] == self.hparams.val_fold + ).with_format("torch", columns=["embedding", "y"]) + self.test_set = self._load(self.hparams.test_dir).with_format( + "torch", columns=["embedding", "y"] ) - def train_dataloader(self) -> DataLoader: + @staticmethod + def _collate(batch: list[dict]) -> tuple: + import torch + + x = torch.stack([b["embedding"].float() for b in batch]) + y = torch.stack([b["y"].long() for b in batch]) + return x, y + + def _loader(self, ds, shuffle: bool) -> DataLoader: return DataLoader( - self.train_set, + ds, batch_size=self.hparams.batch_size, - shuffle=True, + shuffle=shuffle, num_workers=self.hparams.num_workers, + collate_fn=self._collate, ) + def train_dataloader(self) -> DataLoader: + return self._loader(self.train_set, shuffle=True) + def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_set, - batch_size=self.hparams.batch_size, - num_workers=self.hparams.num_workers, - ) + return self._loader(self.val_set, shuffle=False) def test_dataloader(self) -> DataLoader: - return DataLoader( - self.test_set, - batch_size=self.hparams.batch_size, - num_workers=self.hparams.num_workers, - ) + return self._loader(self.test_set, shuffle=False) From c3ef38a47d804457be0c69589c31d83eb9f05127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 7 May 2026 23:16:30 +0200 Subject: [PATCH 003/107] feat(ml): wire up linear probe training with k-fold CV on cached embeddings Datamodule downloads embeddings + kfold artifacts from MLflow, joins on (slide_id, x, y) via pyarrow, applies class mapping, tissue/class coverage filters, and exposes per-fold splits via set_val_fold(). Training script loops folds in a single run and logs per-fold + aggregate metrics. 
Probe adds per-class F1, confusion matrix figures, optional input L2-norm and class weights. Co-Authored-By: Claude Sonnet 4.6 --- .../collapse_alterations_to_other.yaml | 9 + configs/class_mapping/standard.yaml | 11 + configs/data/dataset.yaml | 2 + ...r_probe_collapse_alterations_to_other.yaml | 11 + configs/ml/linear_probe.yaml | 30 ++- ml/PLAN_LINEAR_PROBE.md | 186 ++++++++++++++ ml/data/embeddings_datamodule.py | 234 +++++++++++++++--- ml/models/linear_probe.py | 100 ++++++-- ml/train.py | 50 +++- scripts/submit_linear_probe.py | 18 ++ 10 files changed, 573 insertions(+), 78 deletions(-) create mode 100644 configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml create mode 100644 ml/PLAN_LINEAR_PROBE.md create mode 100644 scripts/submit_linear_probe.py diff --git a/configs/class_mapping/collapse_alterations_to_other.yaml b/configs/class_mapping/collapse_alterations_to_other.yaml index 160aad92..66508f1f 100644 --- a/configs/class_mapping/collapse_alterations_to_other.yaml +++ b/configs/class_mapping/collapse_alterations_to_other.yaml @@ -42,3 +42,12 @@ class_indices: Epithelium: 4 Muscle: 5 Other: 6 + +class_names: + - Nerve + - Blood + - Connective-Tissue + - Fat + - Epithelium + - Muscle + - Other diff --git a/configs/class_mapping/standard.yaml b/configs/class_mapping/standard.yaml index 39866e3c..a623c14e 100644 --- a/configs/class_mapping/standard.yaml +++ b/configs/class_mapping/standard.yaml @@ -46,3 +46,14 @@ class_indices: Inflammation-Chronic: 6 Necrosis: 7 Neoplastic: 8 + +class_names: + - Nerve + - Blood + - Connective-Tissue + - Fat + - Epithelium + - Muscle + - Inflammation-Chronic + - Necrosis + - Neoplastic diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0303fc23..ea2ebfad 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,6 +15,8 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: 
"16ae2d003d88471b924e5f332415232a" + embeddings_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" + kfold_run_id: "2e81b0597b614ba8b675e3b34528c1df" exclusions: bad_slides: diff --git a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml new file mode 100644 index 00000000..74c089ab --- /dev/null +++ b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml @@ -0,0 +1,11 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - _self_ + +embeddings_run_id: ${dataset.mlflow_artifacts.embeddings_run_id} +kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} +embed_dim: 2560 +n_folds: 5 diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml index c8697fc5..7b969c62 100644 --- a/configs/ml/linear_probe.yaml +++ b/configs/ml/linear_probe.yaml @@ -1,12 +1,14 @@ # @package _global_ defaults: - - /class_mapping: standard + - /class_mapping: collapse_alterations_to_other mode: fit -embed_dim: ??? +embed_dim: 2560 embeddings_run_id: ??? +kfold_run_id: ??? 
+n_folds: 5 trainer: _target_: rationai.mlkit.lightning.Trainer @@ -15,17 +17,18 @@ trainer: devices: 1 log_every_n_steps: 50 -val_fold: 0 - data: _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule - train_dir: ${project_path}/embeddings/${embeddings_run_id}/train/tiles - test_dir: ${project_path}/embeddings/${embeddings_run_id}/test/tiles - class_names: ${class_mapping.class_names} - val_fold: ${val_fold} + embeddings_run_id: ${embeddings_run_id} + kfold_run_id: ${kfold_run_id} + kfold_artifact_path: kfold_split/kfold_tiles.parquet + class_mapping: ${class_mapping.class_mapping} + class_indices: ${class_mapping.class_indices} + drop_unmapped: true + tissue_prop_min: 0.0 + class_coverage_min: 0.0 batch_size: 1024 num_workers: 4 - tissue_prop_min: 0.0 model: _target_: ml.models.linear_probe.LinearProbe @@ -33,12 +36,15 @@ model: num_classes: ${len:${class_mapping.class_names}} lr: 1e-3 weight_decay: 0.0 + class_names: ${class_mapping.class_names} metadata: - run_name: "Linear probe (embed_dim=${embed_dim})" - description: Linear probe on cached embeddings + run_name: "Linear probe (embed=${embeddings_run_id}, kfold=${kfold_run_id})" + description: Linear probe on cached Virchow2 embeddings, k-fold CV hyperparams: embed_dim: ${embed_dim} lr: ${model.lr} batch_size: ${data.batch_size} - val_fold: ${val_fold} + n_folds: ${n_folds} + tissue_prop_min: ${data.tissue_prop_min} + class_coverage_min: ${data.class_coverage_min} diff --git a/ml/PLAN_LINEAR_PROBE.md b/ml/PLAN_LINEAR_PROBE.md new file mode 100644 index 00000000..3c6586e8 --- /dev/null +++ b/ml/PLAN_LINEAR_PROBE.md @@ -0,0 +1,186 @@ +# Linear-Probe Training: Implementation Plan + +Scope of the first PR: **train + k-fold validation only**. Test-set evaluation is a follow-up PR (separate entrypoint, no fold loop, possibly slide-level aggregation). Keeping test out of this PR keeps the held-out set untouched while we tune the probe. + +## 0. 
Current state (what already exists) + +- `ml/train.py` — Lightning entrypoint, `mode={fit,test}`, instantiates `data` / `model` / `trainer` from Hydra. +- `ml/data/embeddings_datamodule.py` — `EmbeddingsDataModule` that loads `train_dir` / `test_dir` parquet via `datasets.load_dataset`, filters by `fold`, maps `label` → idx. +- `ml/models/linear_probe.py` — `nn.Linear` head, CE loss, accuracy + macro-F1 (`torchmetrics`). +- `configs/ml/linear_probe.yaml` — wires the above; assumes a single parquet dir per split with `embedding`, `label`, `fold`, `tissue_prop` already joined. +- `configs/class_mapping/collapse_alterations_to_other.yaml` — **the mapping we will use for training**: 7 classes (Nerve, Blood, Connective-Tissue, Fat, Epithelium, Muscle, Other). Inflammation/necrosis/neoplastic alterations are collapsed into `Other`. `standard.yaml` is the alternate 9-class mapping; not used here. + +## 1. Gaps to close before this works end-to-end + +### 1.1 Embeddings and labels are not in the same parquet +`embed.py` writes `train/tiles/*.parquet` with columns `slide_id, x, y, embedding`. `kfold_split.py` writes one `kfold_tiles.parquet` with columns `slide_id, x, y, label, tissue_prop, fold` (+ `roi_coverage_*`). The datamodule today assumes everything is in one file. → must **join on `(slide_id, x, y)`**. + +**Recommendation:** do the join lazily in `setup()` using `pyarrow` / `duckdb` over the parquet files (no extra preprocessing script, no extra MLflow run). The labels parquet is small enough to fit in RAM; embeddings stay memory-mapped. + +### 1.2 Inputs come from MLflow artifacts, not local disk +Current config hardcodes `${project_path}/embeddings/${embeddings_run_id}/...`. The other scripts (`embed.py`, `kfold_split.py`) consistently use `mlflow.artifacts.download_artifacts(run_id=..., artifact_path=...)`. → datamodule should accept `embeddings_run_id` + `kfold_run_id` and download to a cache dir on `prepare_data()` (single-process hook). 
+ +### 1.3 Raw labels in kfold parquet are not the 9 canonical classes +`kfold_split.py` writes `label` as the argmax over the `roi_coverage_<raw_class>` columns → values like `"EPITHELIUM-BB"`, `"NEOPLASTIC-MALIGNANT"`, or `"background"`. The probe expects canonical names (`Epithelium`, `Neoplastic`, …). → apply `class_mapping` (raw → canonical) inside the datamodule. Tiles whose raw label isn't covered by the mapping (today: `"background"`) need a policy — see §1.4. + +### 1.4 Background and coverage-threshold filtering +**Background**: `collapse_alterations_to_other.yaml` has no Background class. `filter_tiles.py` already drops tiles with zero tissue coverage and zero annotation coverage upstream — but by the time we reach the kfold parquet, `"background"` rows can still appear (a tile can have tissue but no annotation overlap, or vice-versa, depending on the filter logic). Drop any rows whose raw label isn't in the mapping. Config knob: `drop_unmapped: true` (default true). + +**Coverage thresholds (live in the datamodule, not a separate PR/script)**: the upstream filter is the coarse cleaning step (any tissue, any annotation). For training-time experimentation, expose two filters as datamodule knobs and apply them after the join, before the train/val split: + +- `tissue_prop_min: float = 0.0` — drop tiles whose total annotation coverage `tissue_prop` is below the threshold. This already exists as a field on the datamodule; keep it. +- `class_coverage_min: float = 0.0` — drop tiles whose **dominant class coverage** (i.e. the `roi_coverage_*` value backing the assigned label, after collapsing per the class mapping) is below the threshold. Forces the label to be "confident" — useful when many tiles are mosaics. + +Both are pure row masks on the labels DataFrame, cheap, and get logged as MLflow params with the run, so threshold sweeps show up cleanly. 
Rationale for not making this a separate preprocessing PR: these thresholds are experimental knobs you'll sweep alongside LR / weight decay; locking them into a parquet artifact would force a re-preprocessing run for every variant. The fundamental cleaning (any-tissue, any-annotation) stays where it belongs in `filter_tiles.py`. + +To support `class_coverage_min` after class collapsing, the datamodule needs the per-class collapsed coverage. Compute it in `setup()`: for each canonical class C, sum the `roi_coverage_<raw>` columns whose raw label maps to C. Then the dominant-class coverage for a tile is `max_C(collapsed_coverage_C)`. The kfold parquet already carries `roi_coverage_*` columns, so no new artifact is needed. + +### 1.5 Config bugs +- `configs/ml/linear_probe.yaml:30` uses `class_mapping.class_names`, which doesn't exist in `standard.yaml`. Either (a) add a `class_names` list to the class-mapping yaml derived from the dict keys, or (b) change the reference. Pick (a) — most readable. +- `${len:...}` resolver — verify it's registered (rationai.mlkit likely registers it, but confirm by running once). + +### 1.6 K-fold orchestration +User wants **all folds in one MLflow run**. Today `train.py` runs a single fold (`val_fold` param). → wrap fit in a loop over folds, log per-fold metrics under `fold_{i}/...` and write aggregate (`val/acc_mean`, `val/acc_std`, `val/f1_macro_mean`, …) at the end. + +### 1.7 `trainer.test()` after fit +`train.py:21` calls `trainer.test(...)` after `trainer.fit(...)`. Remove for this PR (no test in this PR). Re-introduce in the test-PR, in a separate `mode=test` path that doesn't loop folds. + +--- + +## 2. 
Concrete step-by-step plan + +### Step 1 — Fix `configs/class_mapping/collapse_alterations_to_other.yaml` +Add a derived `class_names` list (so configs that reference `class_mapping.class_names` work) and switch `linear_probe.yaml` to default to this mapping: +```yaml +class_names: + - Nerve + - Blood + - Connective-Tissue + - Fat + - Epithelium + - Muscle + - Other +``` +Apply the same change to `standard.yaml` for consistency, but `linear_probe.yaml`'s `defaults` should point at `collapse_alterations_to_other`. Keep `class_mapping` (canonical→raw list) and `class_indices` as they are. + +### Step 2 — Rewrite `EmbeddingsDataModule` +Responsibilities, in order: + +1. **`prepare_data()`** (single-process): + - Download embeddings artifact: `mlflow.artifacts.download_artifacts(run_id=embeddings_run_id, artifact_path="train")`. Cache path on `self`. + - Download kfold artifact: `mlflow.artifacts.download_artifacts(run_id=kfold_run_id, artifact_path="kfold_split/kfold_tiles.parquet")`. +2. **`setup(stage)`**: + - Read kfold parquet into pandas (small: ~few M rows × handful of cols). **Keep** `roi_coverage_*` columns until thresholds are applied. + - Build a `raw → canonical` lookup from the config's `class_mapping` (dict of canonical → list[raw]) and apply it to the `label` column. + - Drop rows whose raw label isn't in the mapping (handles `"background"` and any stragglers) — gated by `drop_unmapped: true`. + - Compute per-tile **collapsed coverage**: for each canonical class C, sum `roi_coverage_<raw>` over its raw members. Add a `dominant_coverage` column = the collapsed coverage of the assigned canonical label. + - Apply `tissue_prop_min` and `class_coverage_min` row masks. Log row-count deltas at each step (initial → after raw-label drop → after `tissue_prop_min` → after `class_coverage_min`) as MLflow metrics so threshold sweeps are interpretable. + - Drop `roi_coverage_*` columns once thresholds are done. 
+ - Load embeddings as an Arrow table: `pyarrow.dataset.dataset(emb_dir, format="parquet").to_table(columns=["slide_id","x","y","embedding"])`. Memory-mapped, zero-copy. + - **Join** on `(slide_id, x, y)` via `pyarrow.Table.join(labels_table, keys=["slide_id","x","y"], join_type="inner")`. The two parquets share this key by construction (both downstream of `filter_tiles/train_tiles.parquet`, neither remaps coords), so the inner-join is effectively 1:1. Use `pyarrow.Table.join` rather than `pandas.merge` — the embedding column is heavy (~2560 × 4B × N) and we want to avoid copies. Wrap the joined Arrow table back into `datasets.Dataset(arrow_table=...)`. + - **Verify the join**: log `n_embeddings`, `n_labels`, `n_joined` as MLflow metrics. If `n_joined < n_labels`, log a warning with the gap — it means the embed run dropped tiles (e.g., upstream API failures past retries) and you'll want to know. + - Map `label` → `y` (int) using `class_indices` (or `class_names.index(label)`). + - For the configured `val_fold`, split into `train_set` (`fold != val_fold`) and `val_set` (`fold == val_fold`). `with_format("torch", columns=["embedding", "y"])`. +3. **`train_dataloader` / `val_dataloader`**: as today. Drop `test_dataloader` / `test_dir` arg in this PR (or leave the arg optional `test_dir: Optional[str] = None` with a `NotImplementedError` if requested — cleaner to just remove until the test PR adds it). +4. Make `val_fold` a settable attribute (not just hparam) so the train script can rebuild the data split per fold without reloading the parquet: + - Cache the joined `full_dataset` on the datamodule. + - Expose a `set_val_fold(fold: int)` that re-derives `train_set` / `val_set` by filtering — this avoids re-downloading and re-joining N times. 
+ +### Step 3 — K-fold loop in `ml/train.py` +Refactor `main()` for `mode == "fit"`: +```python +datamodule = instantiate(config.data) +datamodule.prepare_data() +datamodule.setup("fit") # builds full_dataset once + +per_fold_metrics: list[dict] = [] +for fold in range(config.n_folds): + pl.seed_everything(config.seed + fold) # fresh init per fold + datamodule.set_val_fold(fold) + model = instantiate(config.model) + trainer = instantiate(config.trainer, logger=logger) + trainer.fit(model, datamodule=datamodule) + # collect last-epoch val metrics + per_fold_metrics.append({k: float(v) for k, v in trainer.callback_metrics.items() + if k.startswith("val/")}) + # log per-fold + for k, v in per_fold_metrics[-1].items(): + mlflow.log_metric(f"fold_{fold}/{k}", v) + +# aggregate +import numpy as np +keys = per_fold_metrics[0].keys() +for k in keys: + vals = np.array([m[k] for m in per_fold_metrics]) + mlflow.log_metric(f"{k}_mean", vals.mean()) + mlflow.log_metric(f"{k}_std", vals.std()) +``` +Note: `trainer.test()` removed. Keep `mode == "test"` as `raise NotImplementedError("Test mode arrives in the test-set PR")` for now — clearer than silently breaking. + +### Step 4 — Update `configs/ml/linear_probe.yaml` +- Replace `embeddings_run_id` with two run-id fields: + ```yaml + embeddings_run_id: ??? # e.g. f05076dcd5e64cb2839efe5fb20a22ae + kfold_run_id: ??? # e.g. 2e81b0597b614ba8b675e3b34528c1df + ``` +- Add `n_folds: 5` (or wire to read from kfold artifact metadata if available). +- Drop `test_dir` from `data:` block. 
+- Update `data:` to: + ```yaml + data: + _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule + embeddings_run_id: ${embeddings_run_id} + kfold_run_id: ${kfold_run_id} + kfold_artifact_path: kfold_split/kfold_tiles.parquet # confirm against logged path + class_mapping: ${class_mapping.class_mapping} + class_indices: ${class_mapping.class_indices} + drop_unmapped: true + tissue_prop_min: 0.0 # threshold sweep knob + class_coverage_min: 0.0 # threshold sweep knob + batch_size: 1024 + num_workers: 4 + ``` +- The `defaults` block should select `collapse_alterations_to_other` for `class_mapping`. +- Remove the `val_fold` field at the global level — folds are now driven by the loop. + +### Step 5 — Strengthen `LinearProbe` +Small additions, low risk: +- Per-class F1 logged at `on_validation_epoch_end`: `F1Score(..., average=None)` → log each class as `val/f1_<class_name>`. +- Confusion matrix (`torchmetrics.ConfusionMatrix`) computed at end of validation, logged as an MLflow figure or table per fold. +- Optional embedding L2-normalization toggle (`normalize_input: bool`) — Virchow2 outputs are typically not L2-normalized at the CLS-token stage; making this a flag is one line and a common probe variant to try. +- Optional class weights for CE (defer wiring, but leave a `class_weights: Optional[list[float]] = None` parameter). + +### Step 6 — Logging hygiene +- Log artifacts: the resolved class list, the join coverage stats (`#tiles in embeddings`, `#tiles in kfold`, `#joined`, `#dropped_background`, `#dropped_no_label`). +- Log a one-row summary table per fold: `n_train`, `n_val`, label distribution. +- Set `metadata.run_name` to include both run-ids: `"Linear probe (embed=${embeddings_run_id[:8]}, kfold=${kfold_run_id[:8]})"`. 
+ +### Step 7 — Smoke run +End-to-end against the existing artifacts: +``` +embeddings_run_id=f05076dcd5e64cb2839efe5fb20a22ae +kfold_run_id=2e81b0597b614ba8b675e3b34528c1df +embed_dim=2560 +n_folds=5 # confirm against the kfold run's params +``` +Run for 1–2 epochs first to confirm wiring, then full `max_epochs=30`. + +--- + +## 3. Resolved decisions + +1. **Virchow2 embedding dimension** — 2560. +2. **Kfold artifact path** — `kfold_split/kfold_tiles.parquet`. +3. **n_folds** — 5. +4. **Validation cadence** — fit one fold, then move on (sequential). +5. **Reproducibility of `set_val_fold`** — confirmed: re-instantiate model + seed-per-fold (`seed + fold`). + +--- + +## 4. Out of scope for this PR (next PR) + +- Test-set evaluation (single pass, no folds, possibly with slide-level aggregation). +- Fine-tuning beyond the linear head. +- Class-weighted CE / focal loss / soft-label CE on `roi_coverage` proportions. +- Model selection across folds (best ckpt per fold, ensemble at test time). +- Multi-GPU / DDP — single GPU is plenty for a linear probe on cached embeddings. diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 8175cec7..3e7d340a 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -1,75 +1,231 @@ +import logging +import warnings from pathlib import Path +from typing import Any import lightning as pl -from datasets import load_dataset +import mlflow +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.dataset as pad +from datasets import Dataset +from omegaconf import OmegaConf from torch.utils.data import DataLoader +log = logging.getLogger(__name__) + + class EmbeddingsDataModule(pl.LightningDataModule): - """Linear-probe data module backed by HuggingFace datasets. + """Linear-probe data module. - Expects parquet files with columns: `embedding` (list[float]), `label` (str), - `fold` (int, train side only), and `tissue_prop` (float). 
+ Downloads embeddings and kfold-split artifacts from MLflow, joins them on + (slide_id, x, y), applies class mapping and coverage filters, and exposes + train/val splits per fold via set_val_fold(). """ def __init__( self, - train_dir: str, - test_dir: str, - class_names: list[str], - val_fold: int = 0, + embeddings_run_id: str, + kfold_run_id: str, + class_mapping: dict[str, list[str]], + class_indices: dict[str, int], + kfold_artifact_path: str = "kfold_split/kfold_tiles.parquet", + drop_unmapped: bool = True, + tissue_prop_min: float = 0.0, + class_coverage_min: float = 0.0, batch_size: int = 1024, num_workers: int = 4, - tissue_prop_min: float = 0.0, ) -> None: super().__init__() self.save_hyperparameters() - self.class_to_idx = {c: i for i, c in enumerate(class_names)} - - def _load(self, parquet_dir: str) -> "Dataset": # type: ignore[name-defined] - files = sorted(str(p) for p in Path(parquet_dir).glob("*.parquet")) - ds = load_dataset("parquet", data_files=files, split="train") - if self.hparams.tissue_prop_min > 0: - ds = ds.filter( - lambda r: r["tissue_prop"] >= self.hparams.tissue_prop_min - ) - ds = ds.map(lambda r: {"y": self.class_to_idx[r["label"]]}) - return ds + self.embeddings_run_id = embeddings_run_id + self.kfold_run_id = kfold_run_id + self.kfold_artifact_path = kfold_artifact_path + self.drop_unmapped = drop_unmapped + self.tissue_prop_min = tissue_prop_min + self.class_coverage_min = class_coverage_min + self.batch_size = batch_size + self.num_workers = num_workers + cm: Any = ( + OmegaConf.to_container(class_mapping, resolve=True) + if OmegaConf.is_config(class_mapping) + else class_mapping + ) + ci: Any = ( + OmegaConf.to_container(class_indices, resolve=True) + if OmegaConf.is_config(class_indices) + else class_indices + ) + self._class_mapping: dict[str, list[str]] = dict(cm) + self._raw_to_canonical: dict[str, str] = { + raw: canonical + for canonical, raws in self._class_mapping.items() + for raw in raws + } + self._class_indices: 
dict[str, int] = dict(ci) + self._emb_dir: Path | None = None + self._kfold_path: Path | None = None + self.full_dataset: Dataset | None = None + self.train_set: Dataset | None = None + self.val_set: Dataset | None = None + self._val_fold: int = 0 + self._fold_array: np.ndarray | None = None + + # ------------------------------------------------------------------ + # prepare_data - single-process MLflow download + # ------------------------------------------------------------------ + + def prepare_data(self) -> None: + emb_local = mlflow.artifacts.download_artifacts( + run_id=self.embeddings_run_id, + artifact_path="train", + ) + self._emb_dir = Path(emb_local) + + kfold_local = mlflow.artifacts.download_artifacts( + run_id=self.kfold_run_id, + artifact_path=self.kfold_artifact_path, + ) + self._kfold_path = Path(kfold_local) + + # ------------------------------------------------------------------ + # setup - join, filter, build full_dataset once + # ------------------------------------------------------------------ def setup(self, stage: str) -> None: - train_full = self._load(self.hparams.train_dir) - self.train_set = train_full.filter( - lambda r: r["fold"] != self.hparams.val_fold - ).with_format("torch", columns=["embedding", "y"]) - self.val_set = train_full.filter( - lambda r: r["fold"] == self.hparams.val_fold - ).with_format("torch", columns=["embedding", "y"]) - self.test_set = self._load(self.hparams.test_dir).with_format( - "torch", columns=["embedding", "y"] + if self.full_dataset is not None: + return # already built; use set_val_fold() to change splits + + assert self._emb_dir is not None and self._kfold_path is not None, ( + "Call prepare_data() before setup()" ) + # --- load kfold labels (small) --- + labels_df = pd.read_parquet(self._kfold_path) + n_initial = len(labels_df) + roi_cols = [c for c in labels_df.columns if c.startswith("roi_coverage_")] + + # --- apply raw→canonical label mapping --- + labels_df["label"] = 
labels_df["label"].map(self._raw_to_canonical) + if self.drop_unmapped: + n_before = len(labels_df) + labels_df = labels_df[labels_df["label"].notna()].copy() + dropped = n_before - len(labels_df) + if dropped: + log.warning("Dropping %d tiles with unmapped labels", dropped) + mlflow.log_metric("n_tiles_initial", n_initial) + mlflow.log_metric("n_tiles_after_label_map", len(labels_df)) + + # --- tissue_prop filter --- + if self.tissue_prop_min > 0.0: + labels_df = labels_df[ + labels_df["tissue_prop"] >= self.tissue_prop_min + ].copy() + mlflow.log_metric("n_tiles_after_tissue_prop_filter", len(labels_df)) + + # --- compute dominant_coverage (coverage of the assigned canonical class) --- + dom_cov = np.zeros(len(labels_df), dtype=np.float32) + label_arr = labels_df["label"].to_numpy() + for canonical, raws in self._class_mapping.items(): + cols = [ + f"roi_coverage_{r}" + for r in raws + if f"roi_coverage_{r}" in labels_df.columns + ] + mask: np.ndarray = label_arr == canonical + if cols and mask.any(): + dom_cov[mask] = labels_df.loc[mask, cols].sum(axis=1).to_numpy() + labels_df["dominant_coverage"] = dom_cov + + if self.class_coverage_min > 0.0: + labels_df = labels_df[ + labels_df["dominant_coverage"] >= self.class_coverage_min + ].copy() + mlflow.log_metric("n_tiles_after_class_coverage_filter", len(labels_df)) + + labels_df = labels_df.drop( + columns=[*roi_cols, "dominant_coverage"], errors="ignore" + ) + + # --- map label → integer class index (column name avoids collision with tile y-coord) --- + labels_df["target"] = labels_df["label"].map(self._class_indices).astype(int) + + # --- load embeddings (memory-mapped Arrow) --- + emb_files = sorted(self._emb_dir.rglob("*.parquet")) + emb_table = pad.dataset([str(f) for f in emb_files], format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + mlflow.log_metric("n_embeddings", len(emb_table)) + + # --- inner-join on (slide_id, x, y) --- + labels_table = pa.Table.from_pandas( + 
labels_df[["slide_id", "x", "y", "fold", "tissue_prop", "label", "target"]], + preserve_index=False, + ) + mlflow.log_metric("n_labels", len(labels_table)) + + joined = emb_table.join( + labels_table, + keys=["slide_id", "x", "y"], + join_type="inner", + ) + n_joined = len(joined) + mlflow.log_metric("n_joined", n_joined) + + gap = len(labels_table) - n_joined + if gap > 0: + warnings.warn( + f"Join gap: {gap} label tiles have no matching embedding " + "(embed run may have dropped tiles due to upstream failures).", + stacklevel=2, + ) + + self.full_dataset = Dataset(arrow_table=joined) + self._fold_array = np.asarray(joined.column("fold"), dtype=np.int64) + self.set_val_fold(self._val_fold) + + # ------------------------------------------------------------------ + # fold management + # ------------------------------------------------------------------ + + def set_val_fold(self, fold: int) -> None: + self._val_fold = fold + if self.full_dataset is None or self._fold_array is None: + return + train_idx = np.flatnonzero(self._fold_array != fold).tolist() + val_idx = np.flatnonzero(self._fold_array == fold).tolist() + self.train_set = self.full_dataset.select(train_idx).with_format( + "torch", columns=["embedding", "target"] + ) + self.val_set = self.full_dataset.select(val_idx).with_format( + "torch", columns=["embedding", "target"] + ) + + # ------------------------------------------------------------------ + # dataloaders + # ------------------------------------------------------------------ + @staticmethod - def _collate(batch: list[dict]) -> tuple: + def _collate(batch: list[dict[str, Any]]) -> tuple[Any, Any]: import torch x = torch.stack([b["embedding"].float() for b in batch]) - y = torch.stack([b["y"].long() for b in batch]) + y = torch.stack([b["target"].long() for b in batch]) return x, y - def _loader(self, ds, shuffle: bool) -> DataLoader: + def _loader(self, ds: Dataset, shuffle: bool) -> DataLoader[Any]: return DataLoader( ds, - 
batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=shuffle, - num_workers=self.hparams.num_workers, + num_workers=self.num_workers, collate_fn=self._collate, ) - def train_dataloader(self) -> DataLoader: + def train_dataloader(self) -> DataLoader[Any]: return self._loader(self.train_set, shuffle=True) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self) -> DataLoader[Any]: return self._loader(self.val_set, shuffle=False) - - def test_dataloader(self) -> DataLoader: - return self._loader(self.test_set, shuffle=False) diff --git a/ml/models/linear_probe.py b/ml/models/linear_probe.py index 667df1c6..710d6ab5 100644 --- a/ml/models/linear_probe.py +++ b/ml/models/linear_probe.py @@ -1,8 +1,11 @@ +from typing import cast + import lightning as pl +import mlflow import torch import torch.nn.functional as F from torch import nn, optim -from torchmetrics import Accuracy, F1Score, MetricCollection +from torchmetrics import Accuracy, ConfusionMatrix, F1Score, MetricCollection class LinearProbe(pl.LightningModule): @@ -10,14 +13,23 @@ def __init__( self, embed_dim: int, num_classes: int, + class_names: list[str] | None = None, lr: float = 1e-3, weight_decay: float = 0.0, + normalize_input: bool = False, + class_weights: list[float] | None = None, ) -> None: super().__init__() self.save_hyperparameters() + self.embed_dim = embed_dim + self.num_classes = num_classes + self.class_names = class_names + self.lr = lr + self.weight_decay = weight_decay + self.normalize_input = normalize_input self.head = nn.Linear(embed_dim, num_classes) - metrics = MetricCollection( + base_metrics = MetricCollection( { "acc": Accuracy(task="multiclass", num_classes=num_classes), "f1_macro": F1Score( @@ -25,33 +37,85 @@ def __init__( ), } ) - self.train_metrics = metrics.clone(prefix="train/") - self.val_metrics = metrics.clone(prefix="val/") - self.test_metrics = metrics.clone(prefix="test/") + self.train_metrics = base_metrics.clone(prefix="train/") + 
self.val_metrics = base_metrics.clone(prefix="val/") + + self.val_f1_per_class = F1Score( + task="multiclass", num_classes=num_classes, average=None + ) + self.val_conf_matrix = ConfusionMatrix( + task="multiclass", num_classes=num_classes + ) + + if class_weights is not None: + self.register_buffer( + "class_weights", torch.tensor(class_weights, dtype=torch.float) + ) + else: + self.class_weights = None def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.normalize_input: + x = F.normalize(x, dim=-1) return self.head(x) - def _step(self, batch, metrics, log_prefix: str) -> torch.Tensor: + def training_step( + self, batch: tuple[torch.Tensor, torch.Tensor], _: int + ) -> torch.Tensor: x, y = batch logits = self(x) - loss = F.cross_entropy(logits, y) - self.log(f"{log_prefix}/loss", loss, prog_bar=True) - self.log_dict(metrics(logits, y), prog_bar=True) + loss = F.cross_entropy(logits, y, weight=self.class_weights) + self.log("train/loss", loss, prog_bar=True) + self.log_dict(self.train_metrics(logits, y), prog_bar=True) return loss - def training_step(self, batch, _): - return self._step(batch, self.train_metrics, "train") + def validation_step( + self, batch: tuple[torch.Tensor, torch.Tensor], _: int + ) -> torch.Tensor: + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y, weight=self.class_weights) + self.log("val/loss", loss, prog_bar=True) + self.log_dict(self.val_metrics(logits, y), prog_bar=True) + self.val_f1_per_class.update(logits, y) + self.val_conf_matrix.update(logits, y) + return loss + + def on_validation_epoch_end(self) -> None: + f1_per_class = cast("torch.Tensor", self.val_f1_per_class.compute()) + class_names = self.class_names or [str(i) for i in range(self.num_classes)] + for name, f1 in zip(class_names, f1_per_class, strict=True): + self.log(f"val/f1_{name}", f1) + + conf_mat = cast("torch.Tensor", self.val_conf_matrix.compute()).cpu().numpy() + try: + import matplotlib.pyplot as plt + import numpy as np - def 
validation_step(self, batch, _): - return self._step(batch, self.val_metrics, "val") + fig, ax = plt.subplots(figsize=(8, 7)) + im = ax.imshow(conf_mat, interpolation="nearest") + ax.set( + xticks=np.arange(len(class_names)), + yticks=np.arange(len(class_names)), + xticklabels=class_names, + yticklabels=class_names, + xlabel="Predicted", + ylabel="True", + ) + plt.setp(ax.get_xticklabels(), rotation=45, ha="right") + fig.colorbar(im, ax=ax) + fig.tight_layout() + mlflow.log_figure(fig, f"confusion_matrix_epoch{self.current_epoch}.png") + plt.close(fig) + except Exception: + pass - def test_step(self, batch, _): - return self._step(batch, self.test_metrics, "test") + self.val_f1_per_class.reset() + self.val_conf_matrix.reset() - def configure_optimizers(self): + def configure_optimizers(self) -> optim.Optimizer: return optim.AdamW( self.parameters(), - lr=self.hparams.lr, - weight_decay=self.hparams.weight_decay, + lr=self.lr, + weight_decay=self.weight_decay, ) diff --git a/ml/train.py b/ml/train.py index 0cb3213b..0e1f1ccc 100644 --- a/ml/train.py +++ b/ml/train.py @@ -1,29 +1,61 @@ +from typing import TYPE_CHECKING, Any + import hydra import lightning as pl +import mlflow +import numpy as np from hydra.utils import instantiate from omegaconf import DictConfig from rationai.mlkit import autolog, with_cli_args from rationai.mlkit.lightning.loggers import MLFlowLogger +if TYPE_CHECKING: + from ml.data.embeddings_datamodule import EmbeddingsDataModule + + @with_cli_args(["+ml=linear_probe"]) @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: - pl.seed_everything(config.seed) - - datamodule: pl.LightningDataModule = instantiate(config.data) - model: pl.LightningModule = instantiate(config.model) - trainer: pl.Trainer = instantiate(config.trainer, logger=logger) - if config.mode == "fit": - trainer.fit(model, datamodule=datamodule, ckpt_path=config.checkpoint) - 
trainer.test(model, datamodule=datamodule) + _fit(config, logger) elif config.mode == "test": - trainer.test(model, datamodule=datamodule, ckpt_path=config.checkpoint) + raise NotImplementedError("Test mode arrives in the test-set PR") else: raise ValueError(f"Unknown mode: {config.mode}") +def _fit(config: DictConfig, logger: MLFlowLogger) -> None: + datamodule: EmbeddingsDataModule = instantiate(config.data) + datamodule.prepare_data() + datamodule.setup("fit") + + per_fold_metrics: list[dict[str, Any]] = [] + for fold in range(config.n_folds): + pl.seed_everything(config.seed + fold) + datamodule.set_val_fold(fold) + + model: pl.LightningModule = instantiate(config.model) + trainer: pl.Trainer = instantiate(config.trainer, logger=logger) + trainer.fit(model, datamodule=datamodule) + + fold_metrics = { + k: float(v) + for k, v in trainer.callback_metrics.items() + if k.startswith("val/") + } + per_fold_metrics.append(fold_metrics) + for k, v in fold_metrics.items(): + mlflow.log_metric(f"fold_{fold}/{k}", v) + + if per_fold_metrics: + keys = per_fold_metrics[0].keys() + for k in keys: + vals = np.array([m[k] for m in per_fold_metrics]) + mlflow.log_metric(f"{k}_mean", float(vals.mean())) + mlflow.log_metric(f"{k}_std", float(vals.std())) + + if __name__ == "__main__": main() diff --git a/scripts/submit_linear_probe.py b/scripts/submit_linear_probe.py new file mode 100644 index 00000000..bca9b9c8 --- /dev/null +++ b/scripts/submit_linear_probe.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-linear-probe", + username=..., + cpu=4, + memory="32Gi", + gpu="A40", + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml.train +ml=... 
+experiment=...", + ], + storage=[storage.secure.PROJECTS], +) From c644f22c6cb48e414002e87189d3709e376e943b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 7 May 2026 23:40:42 +0200 Subject: [PATCH 004/107] fix(configs): use override for class_mapping in experiment yaml The experiment file was declaring /class_mapping as a fresh default while configs/ml/linear_probe.yaml already had one, which Hydra rejects as a duplicate. Mark it as an override so the experiment replaces the base default. Co-Authored-By: Claude Sonnet 4.6 --- .../ml/linear_probe_collapse_alterations_to_other.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml index 74c089ab..33996716 100644 --- a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml +++ b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml @@ -2,7 +2,7 @@ defaults: - /data: dataset - - /class_mapping: collapse_alterations_to_other + - override /class_mapping: collapse_alterations_to_other - _self_ embeddings_run_id: ${dataset.mlflow_artifacts.embeddings_run_id} From 564b0b1b6e4302c8dfd9d7b1f59845b94edeb512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 7 May 2026 23:45:23 +0200 Subject: [PATCH 005/107] fix(scripts): drop duplicate +ml= from linear-probe submit command ml/train.py uses @with_cli_args(["+ml=linear_probe"]), so the decorator already injects that arg. Passing it again on the command line caused Hydra to load configs/ml/linear_probe.yaml twice and reject duplicate defaults. Rely on the decorator and pass only +experiment=... 
Co-Authored-By: Claude Sonnet 4.6 --- scripts/submit_linear_probe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/submit_linear_probe.py b/scripts/submit_linear_probe.py index bca9b9c8..61e5c847 100644 --- a/scripts/submit_linear_probe.py +++ b/scripts/submit_linear_probe.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-linear-probe", - username=..., + username="vcifka", cpu=4, memory="32Gi", - gpu="A40", + gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/linear-probe https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml.train +ml=... +experiment=...", + "uv run python -m ml.train +experiment=ml/linear_probe_collapse_alterations_to_other", ], storage=[storage.secure.PROJECTS], ) From 3a77adcc68286b43058e2747fa5fd59a381c144d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 7 May 2026 23:50:50 +0200 Subject: [PATCH 006/107] fix(ml): register random_seed/len resolvers and unflatten class_mapping refs Two interpolation problems prevented Hydra from resolving the linear-probe config: 1. configs/ml.yaml uses ${random_seed:} and configs/ml/linear_probe.yaml uses ${len:...}, but neither resolver is registered anywhere. Register both at module import time in ml/train.py. 2. The class_mapping yamls use # @package _global_, so class_mapping, class_indices, and class_names land at the config root. The references in linear_probe.yaml were doubly nested (e.g. class_mapping.class_mapping). Drop the prefix. 
Co-Authored-By: Claude Sonnet 4.6 --- configs/ml/linear_probe.yaml | 8 ++++---- ml/train.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml index 7b969c62..e2ef1445 100644 --- a/configs/ml/linear_probe.yaml +++ b/configs/ml/linear_probe.yaml @@ -22,8 +22,8 @@ data: embeddings_run_id: ${embeddings_run_id} kfold_run_id: ${kfold_run_id} kfold_artifact_path: kfold_split/kfold_tiles.parquet - class_mapping: ${class_mapping.class_mapping} - class_indices: ${class_mapping.class_indices} + class_mapping: ${class_mapping} + class_indices: ${class_indices} drop_unmapped: true tissue_prop_min: 0.0 class_coverage_min: 0.0 @@ -33,10 +33,10 @@ data: model: _target_: ml.models.linear_probe.LinearProbe embed_dim: ${embed_dim} - num_classes: ${len:${class_mapping.class_names}} + num_classes: ${len:${class_names}} lr: 1e-3 weight_decay: 0.0 - class_names: ${class_mapping.class_names} + class_names: ${class_names} metadata: run_name: "Linear probe (embed=${embeddings_run_id}, kfold=${kfold_run_id})" diff --git a/ml/train.py b/ml/train.py index 0e1f1ccc..c13418ea 100644 --- a/ml/train.py +++ b/ml/train.py @@ -1,3 +1,4 @@ +import secrets from typing import TYPE_CHECKING, Any import hydra @@ -5,7 +6,7 @@ import mlflow import numpy as np from hydra.utils import instantiate -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from rationai.mlkit import autolog, with_cli_args from rationai.mlkit.lightning.loggers import MLFlowLogger @@ -14,6 +15,14 @@ from ml.data.embeddings_datamodule import EmbeddingsDataModule +if not OmegaConf.has_resolver("random_seed"): + OmegaConf.register_new_resolver( + "random_seed", lambda: secrets.randbits(32), use_cache=True + ) +if not OmegaConf.has_resolver("len"): + OmegaConf.register_new_resolver("len", len) + + @with_cli_args(["+ml=linear_probe"]) @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog From 
11b19f09448ca29efbe6b00b70c606ce489c37ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 11:36:54 +0200 Subject: [PATCH 007/107] fix(ml): accept already-canonical labels in datamodule label map MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The filtered tiles parquet collapses ROI columns at tiling time, so kfold writes canonical names ("Epithelium", etc.) directly into `label`. The raw→canonical lookup built from the BB-suffixed YAML lists matched none of these and dropped the entire 1.1M-tile dataset under drop_unmapped=True. Extend _raw_to_canonical with identity entries for every canonical class so modern parquets pass through while legacy un-collapsed labels still collapse correctly. "background" stays unmapped → dropped, as intended. Co-Authored-By: Claude Opus 4.7 --- ml/PLAN_LINEAR_PROBE.md | 2 ++ ml/data/embeddings_datamodule.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ml/PLAN_LINEAR_PROBE.md b/ml/PLAN_LINEAR_PROBE.md index 3c6586e8..7b9388b3 100644 --- a/ml/PLAN_LINEAR_PROBE.md +++ b/ml/PLAN_LINEAR_PROBE.md @@ -23,6 +23,8 @@ Current config hardcodes `${project_path}/embeddings/${embeddings_run_id}/...`. ### 1.3 Raw labels in kfold parquet are not the 9 canonical classes `kfold_split.py` writes `label = roi_coverage_` argmax → values like `"EPITHELIUM-BB"`, `"NEOPLASTIC-MALIGNANT"`, or `"background"`. The probe expects canonical names (`Epithelium`, `Neoplastic`, …). → apply `class_mapping` (raw → canonical) inside the datamodule. Tiles whose raw label isn't covered by the mapping (today: `"background"`) need a policy — see §1.4. +**Update (post-smoke-run):** the *filtered* tiles parquet — which is what kfold now runs over (commit `64b2000`) — already has the ROI columns collapsed upstream (`roi_coverage_Nerve`, `roi_coverage_Epithelium`, …). 
So `kfold_split.derive_labels` strips the prefix and writes already-canonical names (`"Nerve"`, `"Epithelium"`, …) plus `"background"` straight into `label`. The raw→canonical lookup built from the YAML's BB-suffixed lists matches **none** of these, so with `drop_unmapped=True` the entire dataset gets dropped (observed: `n_tiles_after_label_map = 0` on a 1.1M-tile run). Fix: extend `_raw_to_canonical` with identity entries for every canonical class (`{c: c for c in class_indices}`). This keeps backward-compat with legacy un-collapsed parquets while letting modern canonical-label parquets pass through unchanged. `"background"` stays unmapped → dropped, which is what we want (no ROI overlap = no label). + ### 1.4 Background and coverage-threshold filtering **Background**: `collapse_alterations_to_other.yaml` has no Background class. `filter_tiles.py` already drops tiles with zero tissue coverage and zero annotation coverage upstream — so by the time we reach the kfold parquet, `"background"` rows can still appear (a tile can have tissue but no annotation overlap, or vice-versa, depending on the filter logic). Drop any rows whose raw label isn't in the mapping. Config knob: `drop_unmapped: true` (default true). diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 3e7d340a..459bfef0 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -59,12 +59,17 @@ def __init__( else class_indices ) self._class_mapping: dict[str, list[str]] = dict(cm) + self._class_indices: dict[str, int] = dict(ci) self._raw_to_canonical: dict[str, str] = { raw: canonical for canonical, raws in self._class_mapping.items() for raw in raws } - self._class_indices: dict[str, int] = dict(ci) + # Accept already-canonical labels as identity. The filtered tiles parquet + # collapses ROI columns at tiling time, so kfold writes canonical names + # (e.g. 
"Epithelium") directly into `label`; the raw→canonical lists in + # the class-mapping YAML still cover legacy un-collapsed parquets. + self._raw_to_canonical.update({c: c for c in self._class_indices}) self._emb_dir: Path | None = None self._kfold_path: Path | None = None self.full_dataset: Dataset | None = None From c6bfe8eca3068293af83552f0f7954f37d4da646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 11:40:57 +0200 Subject: [PATCH 008/107] feat(ml): class-weighted CE, raise class_coverage_min to 0.5 - Add EmbeddingsDataModule.compute_class_weights("balanced"|"inverse") using sklearn-style weights from the current train fold. - train.py resolves class_weights="balanced"/"inverse" via the datamodule and passes the resulting list to LinearProbe at instantiate time (per-fold, since splits change). - Bump class_coverage_min from 0.0 to 0.5 to drop mosaic tiles. - Drop the redundant /class_mapping default from configs/ml/linear_probe.yaml; experiment files now own the choice. 
Co-Authored-By: Claude Opus 4.7 --- ...r_probe_collapse_alterations_to_other.yaml | 2 +- configs/ml/linear_probe.yaml | 6 ++--- ml/data/embeddings_datamodule.py | 24 +++++++++++++++++++ ml/train.py | 8 ++++++- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml index 33996716..74c089ab 100644 --- a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml +++ b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml @@ -2,7 +2,7 @@ defaults: - /data: dataset - - override /class_mapping: collapse_alterations_to_other + - /class_mapping: collapse_alterations_to_other - _self_ embeddings_run_id: ${dataset.mlflow_artifacts.embeddings_run_id} diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml index e2ef1445..0e2f9cbb 100644 --- a/configs/ml/linear_probe.yaml +++ b/configs/ml/linear_probe.yaml @@ -1,8 +1,5 @@ # @package _global_ -defaults: - - /class_mapping: collapse_alterations_to_other - mode: fit embed_dim: 2560 @@ -26,7 +23,7 @@ data: class_indices: ${class_indices} drop_unmapped: true tissue_prop_min: 0.0 - class_coverage_min: 0.0 + class_coverage_min: 0.5 batch_size: 1024 num_workers: 4 @@ -37,6 +34,7 @@ model: lr: 1e-3 weight_decay: 0.0 class_names: ${class_names} + class_weights: balanced # null | "balanced" | "inverse" | list[float] metadata: run_name: "Linear probe (embed=${embeddings_run_id}, kfold=${kfold_run_id})" diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 459bfef0..7f9472b7 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -208,6 +208,30 @@ def set_val_fold(self, fold: int) -> None: "torch", columns=["embedding", "target"] ) + def compute_class_weights(self, method: str = "balanced") -> list[float]: + """Compute per-class loss weights from the current training fold. 
+ + ``balanced`` follows sklearn's ``compute_class_weight``: + ``w_c = n_samples / (n_classes * count_c)``. ``inverse`` uses + ``1 / count_c`` normalised to mean 1. + """ + if self.full_dataset is None or self._fold_array is None: + raise RuntimeError("call setup()/set_val_fold() before compute_class_weights()") + targets = np.asarray(self.full_dataset.data.column("target"), dtype=np.int64) + train_mask = self._fold_array != self._val_fold + train_targets = targets[train_mask] + n_classes = len(self._class_indices) + counts = np.bincount(train_targets, minlength=n_classes).astype(np.float64) + counts = np.maximum(counts, 1.0) + if method == "balanced": + weights = train_targets.size / (n_classes * counts) + elif method == "inverse": + weights = 1.0 / counts + weights = weights / weights.mean() + else: + raise ValueError(f"Unknown class-weight method: {method!r}") + return weights.tolist() + # ------------------------------------------------------------------ # dataloaders # ------------------------------------------------------------------ diff --git a/ml/train.py b/ml/train.py index c13418ea..0fbfdd03 100644 --- a/ml/train.py +++ b/ml/train.py @@ -45,7 +45,13 @@ def _fit(config: DictConfig, logger: MLFlowLogger) -> None: pl.seed_everything(config.seed + fold) datamodule.set_val_fold(fold) - model: pl.LightningModule = instantiate(config.model) + weights_spec = config.model.get("class_weights", None) + weights = ( + datamodule.compute_class_weights(weights_spec) + if isinstance(weights_spec, str) + else weights_spec + ) + model: pl.LightningModule = instantiate(config.model, class_weights=weights) trainer: pl.Trainer = instantiate(config.trainer, logger=logger) trainer.fit(model, datamodule=datamodule) From 894c27b03dece5eb68620665786a9e404e978364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 11:54:42 +0200 Subject: [PATCH 009/107] fix: sort only tiles parquet --- ml/data/embeddings_datamodule.py | 
2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 7f9472b7..dbf3fa29 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -158,7 +158,7 @@ def setup(self, stage: str) -> None: labels_df["target"] = labels_df["label"].map(self._class_indices).astype(int) # --- load embeddings (memory-mapped Arrow) --- - emb_files = sorted(self._emb_dir.rglob("*.parquet")) + emb_files = sorted((self._emb_dir / "tiles").rglob("*.parquet")) emb_table = pad.dataset([str(f) for f in emb_files], format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) From fc824adc7f3a2181bc586b0afa237bd393f95c4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 12:08:54 +0200 Subject: [PATCH 010/107] fix: log join types of tile keys --- ml/data/embeddings_datamodule.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index dbf3fa29..2fa274ee 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -162,6 +162,7 @@ def setup(self, stage: str) -> None: emb_table = pad.dataset([str(f) for f in emb_files], format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + log.info("Embeddings schema: %s", emb_table.schema) mlflow.log_metric("n_embeddings", len(emb_table)) # --- inner-join on (slide_id, x, y) --- @@ -169,8 +170,36 @@ def setup(self, stage: str) -> None: labels_df[["slide_id", "x", "y", "fold", "tissue_prop", "label", "target"]], preserve_index=False, ) + log.info("Labels schema: %s", labels_table.schema) mlflow.log_metric("n_labels", len(labels_table)) + # Acero join requires concrete, matching types on join keys. Normalise both + # tables: slide_id → large_string, x/y → int64. 
Handles null-type columns + # that can arise when Ray writes parquets with an inferred null schema, and + # type mismatches between string vs large_string across files. + join_key_types: dict[str, pa.DataType] = { + "slide_id": pa.large_string(), + "x": pa.int64(), + "y": pa.int64(), + } + for tbl_name, tbl in [("emb", emb_table), ("labels", labels_table)]: + for col_name, target_type in join_key_types.items(): + idx = tbl.schema.get_field_index(col_name) + actual_type = tbl.schema.field(col_name).type + if actual_type != target_type: + log.warning( + "%s.%s has type %s, casting to %s", + tbl_name, col_name, actual_type, target_type, + ) + if tbl_name == "emb": + emb_table = emb_table.set_column( + idx, col_name, emb_table.column(col_name).cast(target_type) + ) + else: + labels_table = labels_table.set_column( + idx, col_name, labels_table.column(col_name).cast(target_type) + ) + joined = emb_table.join( labels_table, keys=["slide_id", "x", "y"], From 11931d12c7346ef6dba8ff6dc8810dd14eb48f39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 12:25:50 +0200 Subject: [PATCH 011/107] fix: remove embeddings from the join --- ml/data/embeddings_datamodule.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 2fa274ee..1b5827af 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -177,12 +177,18 @@ def setup(self, stage: str) -> None: # tables: slide_id → large_string, x/y → int64. Handles null-type columns # that can arise when Ray writes parquets with an inferred null schema, and # type mismatches between string vs large_string across files. + # Acero also does not support list-typed non-key fields, so the embedding + # column is excluded from the join and reattached via row-index lookup. 
join_key_types: dict[str, pa.DataType] = { "slide_id": pa.large_string(), "x": pa.int64(), "y": pa.int64(), } - for tbl_name, tbl in [("emb", emb_table), ("labels", labels_table)]: + + emb_keys = emb_table.select(["slide_id", "x", "y"]).append_column( + "_emb_row", pa.array(range(len(emb_table)), type=pa.int64()) + ) + for tbl_name, tbl in [("emb_keys", emb_keys), ("labels", labels_table)]: for col_name, target_type in join_key_types.items(): idx = tbl.schema.get_field_index(col_name) actual_type = tbl.schema.field(col_name).type @@ -191,20 +197,25 @@ def setup(self, stage: str) -> None: "%s.%s has type %s, casting to %s", tbl_name, col_name, actual_type, target_type, ) - if tbl_name == "emb": - emb_table = emb_table.set_column( - idx, col_name, emb_table.column(col_name).cast(target_type) + if tbl_name == "emb_keys": + emb_keys = emb_keys.set_column( + idx, col_name, emb_keys.column(col_name).cast(target_type) ) else: labels_table = labels_table.set_column( idx, col_name, labels_table.column(col_name).cast(target_type) ) - joined = emb_table.join( + joined_meta = emb_keys.join( labels_table, keys=["slide_id", "x", "y"], join_type="inner", ) + emb_indices = joined_meta.column("_emb_row") + joined = joined_meta.drop_columns(["_emb_row"]).append_column( + pa.field("embedding", emb_table.schema.field("embedding").type), + emb_table.column("embedding").take(emb_indices), + ) n_joined = len(joined) mlflow.log_metric("n_joined", n_joined) From fb6b32020c70bc87f68f7d5b258a97d93484338f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 12:45:43 +0200 Subject: [PATCH 012/107] fix: remove label column --- ml/data/embeddings_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 1b5827af..6b279912 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -167,7 +167,7 @@ def setup(self, 
stage: str) -> None: # --- inner-join on (slide_id, x, y) --- labels_table = pa.Table.from_pandas( - labels_df[["slide_id", "x", "y", "fold", "tissue_prop", "label", "target"]], + labels_df[["slide_id", "x", "y", "fold", "tissue_prop", "target"]], preserve_index=False, ) log.info("Labels schema: %s", labels_table.schema) From 7434ae95a305fe8b496a25cf95845df2898cd528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 14:58:57 +0200 Subject: [PATCH 013/107] fix: prevent overflow --- ml/data/embeddings_datamodule.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py index 6b279912..fc2fe74e 100644 --- a/ml/data/embeddings_datamodule.py +++ b/ml/data/embeddings_datamodule.py @@ -212,9 +212,17 @@ def setup(self, stage: str) -> None: join_type="inner", ) emb_indices = joined_meta.column("_emb_row") + + emb_type = emb_table.schema.field("embedding").type + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_type): + large_list_type = pa.large_list(emb_type.value_type) + emb_col = emb_col.cast(large_list_type) + emb_type = large_list_type + joined = joined_meta.drop_columns(["_emb_row"]).append_column( - pa.field("embedding", emb_table.schema.field("embedding").type), - emb_table.column("embedding").take(emb_indices), + pa.field("embedding", emb_type), + emb_col.take(emb_indices), ) n_joined = len(joined) mlflow.log_metric("n_joined", n_joined) From bef70dfd0502f33ac4224986da6a2bb19ed6cc1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:27:58 +0200 Subject: [PATCH 014/107] feat: add embedding dataset build pipeline Extract derive_labels logic to shared preprocessing/_labels.py, then use it in both split/kfold_split.py and the new embedding_dataset pipeline. 
The new pipeline joins k-fold (train) / filter_tiles (test) tile metadata with precomputed embeddings after applying tissue + per-dominant-class ROI thresholds, and emits a SlidesTilesLoader-compatible Parquet dataset as an MLflow artifact. Co-Authored-By: Claude Sonnet 4.6 --- .../preprocessing/embedding_dataset.yaml | 16 ++ configs/preprocessing/embedding_dataset.yaml | 13 ++ preprocessing/_labels.py | 24 ++ preprocessing/embedding_dataset.py | 212 ++++++++++++++++++ scripts/submit_embedding_dataset.py | 18 ++ split/kfold_split.py | 7 +- 6 files changed, 286 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/preprocessing/embedding_dataset.yaml create mode 100644 configs/preprocessing/embedding_dataset.yaml create mode 100644 preprocessing/_labels.py create mode 100644 preprocessing/embedding_dataset.py create mode 100644 scripts/submit_embedding_dataset.py diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml new file mode 100644 index 00000000..bfe24a04 --- /dev/null +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -0,0 +1,16 @@ +# @package _global_ + +defaults: + - /data: dataset + - _self_ + +tissue_prop_min: 0.5 +thresholds: ??? + +metadata: + run_name: Embedding dataset ${dataset.name} + description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})." 
+ hyperparams: + kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} + filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} diff --git a/configs/preprocessing/embedding_dataset.yaml b/configs/preprocessing/embedding_dataset.yaml new file mode 100644 index 00000000..f4af56a6 --- /dev/null +++ b/configs/preprocessing/embedding_dataset.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +mlflow_artifact_path: embedding_dataset + +tissue_prop_min: ??? +thresholds: ??? + +metadata: + run_name: "Embedding dataset ${dataset.name}" + description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings." + hyperparams: + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} diff --git a/preprocessing/_labels.py b/preprocessing/_labels.py new file mode 100644 index 00000000..229f7dc6 --- /dev/null +++ b/preprocessing/_labels.py @@ -0,0 +1,24 @@ +"""Shared helpers for deriving tile labels from roi_coverage_* columns.""" + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import pandas as pd + + +def compute_label_and_tissue_prop( + roi_data: Mapping[str, Any], + roi_cols: list[str], +) -> tuple[np.ndarray, np.ndarray]: + """Compute (label, tissue_prop) from roi_coverage_* columns. + + label = argmax across roi_cols (with ``roi_coverage_`` prefix stripped), + falling back to ``"background"`` whenever all coverages are zero. + tissue_prop = sum across roi_cols. 
+ """ + roi_df = pd.DataFrame({col: roi_data[col] for col in roi_cols}) + tp = roi_df.sum(axis=1).to_numpy() + lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").to_numpy() + lbl[tp == 0] = "background" + return lbl, tp diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py new file mode 100644 index 00000000..1e9028c1 --- /dev/null +++ b/preprocessing/embedding_dataset.py @@ -0,0 +1,212 @@ +"""Build an embedding training dataset by joining tile metadata with embeddings. + +Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles +metadata (test), applies tissue + per-class ROI thresholds before the join, and +emits a training-ready Parquet dataset (per-split ``slides.parquet`` + +``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``. +""" + +import shutil +import tempfile +from pathlib import Path + +import hydra +import mlflow +import mlflow.artifacts +import pandas as pd +import pyarrow.dataset as pads +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog, with_cli_args +from rationai.mlkit.lightning.loggers import MLFlowLogger + +from preprocessing._labels import compute_label_and_tissue_prop + + +def apply_thresholds( + df: pd.DataFrame, + tissue_prop_min: float, + thresholds: dict[str, float], + roi_cols: list[str], +) -> tuple[pd.DataFrame, int]: + """Filter df by tissue_prop_min then by per-dominant-class roi threshold. + + Returns ``(filtered_df, after_tissue_count)`` so the caller can log both + intermediate counts. 
+ """ + df = df[df["tissue_prop"] >= tissue_prop_min] + after_tissue = len(df) + if df.empty: + return df, after_tissue + + roi_only = df[roi_cols] + dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") + dominant_value = roi_only.max(axis=1).to_numpy() + threshold_per_row = dominant.map(thresholds).to_numpy() + keep = dominant_value >= threshold_per_row + return df[keep].copy(), after_tissue + + +def join_embeddings( + tiles_df: pd.DataFrame, + embedding_run_id: str, + embedding_split: str, +) -> tuple[pd.DataFrame, int]: + """Join filtered tile metadata with embeddings on (slide_id, x, y).""" + emb_dir = mlflow.artifacts.download_artifacts( + run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" + ) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + emb_df = emb_table.to_pandas() + del emb_table + + merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") + dropped_no_embedding = len(tiles_df) - len(merged) + return merged, dropped_no_embedding + + +def process_split( + split_name: str, + src_run_id: str, + src_artifact_path: str, + embedding_run_id: str, + tissue_prop_min: float, + thresholds: dict[str, float], + output_split_dir: Path, + derive: bool, +) -> dict[str, int]: + src_local = mlflow.artifacts.download_artifacts( + run_id=src_run_id, artifact_path=src_artifact_path + ) + df = pads.dataset(src_local, format="parquet").to_table().to_pandas() + input_count = len(df) + + roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] + if not roi_cols: + raise RuntimeError( + f"No roi_coverage_* columns in {src_artifact_path}. " + "Cannot apply class thresholds." 
+ ) + + classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} + missing = classes_in_data - set(thresholds.keys()) + if missing: + raise ValueError( + f"thresholds is missing entries for roi_coverage_* classes present " + f"in data: {sorted(missing)}" + ) + + if derive: + lbl, tp = compute_label_and_tissue_prop(df, roi_cols) + df["label"] = lbl + df["tissue_prop"] = tp + + df, after_tissue_filter = apply_thresholds( + df, tissue_prop_min, thresholds, roi_cols + ) + after_class_threshold = len(df) + if after_class_threshold == 0: + raise RuntimeError( + f"All {input_count} tiles dropped by thresholds for split '{split_name}'." + ) + + drop_cols = [ + c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) + ] + df = df.drop(columns=drop_cols) + + merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) + if dropped_no_embedding != 0: + print( + f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " + "no matching embedding and were dropped on join.", + flush=True, + ) + + merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + + output_split_dir.mkdir(parents=True, exist_ok=True) + merged.to_parquet(output_split_dir / "tiles.parquet", index=False) + + slides_local = mlflow.artifacts.download_artifacts( + run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" + ) + shutil.copy(slides_local, output_split_dir / "slides.parquet") + + log_label_distributions(split_name, merged) + + return { + "input_count": input_count, + "after_tissue_filter": after_tissue_filter, + "after_class_threshold": after_class_threshold, + "after_join": len(merged), + "dropped_no_embedding": dropped_no_embedding, + } + + +def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: + label_dist = ( + df["label"].value_counts().rename_axis("label").reset_index(name="count") + ) + mlflow.log_table( + data=label_dist, + 
artifact_file=f"fold_statistics/{split_name}_label_distribution.json", + ) + + if "fold" in df.columns: + fold_dist = ( + df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() + ) + mlflow.log_table( + data=fold_dist, + artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json", + ) + + +@with_cli_args(["+preprocessing=embedding_dataset"]) +@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + artifacts = config.dataset.mlflow_artifacts + kfold_run_id = artifacts.kfold_run_id + filter_tiles_run_id = artifacts.filter_tiles_run_id + embedding_run_id = artifacts.embedding_run_id + + tissue_prop_min = float(config.tissue_prop_min) + if tissue_prop_min <= 0: + raise ValueError( + f"tissue_prop_min must be > 0 (got {tissue_prop_min}); " + "otherwise background tiles are not filtered out." + ) + raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True) + if not isinstance(raw_thresholds, dict): + raise TypeError("config.thresholds must be a mapping of class -> threshold") + thresholds = {str(k): float(v) for k, v in raw_thresholds.items()} + + splits = [ + ("train", kfold_run_id, "kfold_split/kfold_tiles.parquet", False), + ("test", filter_tiles_run_id, "filter_tiles/test_tiles.parquet", True), + ] + + with tempfile.TemporaryDirectory() as tmp_root: + tmp_root_path = Path(tmp_root) + for split_name, src_run_id, src_artifact_path, derive in splits: + stats = process_split( + split_name=split_name, + src_run_id=src_run_id, + src_artifact_path=src_artifact_path, + embedding_run_id=embedding_run_id, + tissue_prop_min=tissue_prop_min, + thresholds=thresholds, + output_split_dir=tmp_root_path / split_name, + derive=derive, + ) + for key, value in stats.items(): + mlflow.log_metric(f"{split_name}_{key}", value) + + mlflow.log_artifacts(str(tmp_root_path), config.mlflow_artifact_path) + + +if __name__ == "__main__": + main() 
diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py new file mode 100644 index 00000000..bbe4063f --- /dev/null +++ b/scripts/submit_embedding_dataset.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-embedding-dataset", + username=..., + cpu=8, + memory="32Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m preprocessing.embedding_dataset +experiment=...", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/split/kfold_split.py b/split/kfold_split.py index 150961ee..a17c299a 100644 --- a/split/kfold_split.py +++ b/split/kfold_split.py @@ -12,6 +12,8 @@ from rationai.mlkit.lightning.loggers import MLFlowLogger from sklearn.model_selection import StratifiedKFold +from preprocessing._labels import compute_label_and_tissue_prop + def derive_labels( dataset: Dataset, @@ -20,10 +22,7 @@ def derive_labels( """Derive label, tissue_prop, and slide_id arrays from the dataset.""" def compute(batch: dict[str, Any]) -> dict[str, Any]: - roi_df = pd.DataFrame({col: batch[col] for col in roi_cols}) - tp = roi_df.sum(axis=1).values - lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").values - lbl[tp == 0] = "background" + lbl, tp = compute_label_and_tissue_prop(batch, roi_cols) return {"label": lbl.tolist(), "tissue_prop": tp.tolist()} label_ds = dataset.select_columns(["slide_id", *roi_cols]).map( From 911bec2c7b40149a681ab7be97854ab32e06bdff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:35:16 +0200 Subject: [PATCH 015/107] feat: add class thresholds and run ids --- configs/data/dataset.yaml | 2 ++ .../experiment/preprocessing/embedding_dataset.yaml | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml
index e13fec8d..732575b2 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,6 +14,8 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" + kfold_run_id: "850c81506684450b9af92296acfd045a" + embedding_run_id: "06d2d8eb088c4e04b04435940774c7aa" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml index bfe24a04..71f3e687 100644 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -4,8 +4,15 @@ defaults: - /data: dataset - _self_ -tissue_prop_min: 0.5 -thresholds: ??? +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.0 + Fat: 0.0 + Epithelium: 0.0 + Muscle: 0.0 + Other: 0.0 metadata: run_name: Embedding dataset ${dataset.name} From 1a0239537e59eddb586889085fd1fec9ad193e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:43:31 +0200 Subject: [PATCH 016/107] fix: wrong run id --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 732575b2..0497d479 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,7 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "06d2d8eb088c4e04b04435940774c7aa" + embedding_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" From b38465e6185acae770428314752d2c0cf26c7541 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 16:34:37 +0200 Subject: [PATCH 017/107] feat: add timing --- configs/data/dataset.yaml | 2 +- preprocessing/embedding_dataset.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0497d479..0cf33e25 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,7 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" + embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 1e9028c1..741ffdd0 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -52,16 +53,28 @@ def join_embeddings( embedding_split: str, ) -> tuple[pd.DataFrame, int]: """Join filtered tile metadata with embeddings on (slide_id, x, y).""" + t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) + print(f"[timing] download embeddings: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() + emb_ds = pads.dataset(emb_dir, format="parquet") + print(f"[timing] embedding dataset has {emb_ds.count_rows()} rows", flush=True) + emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", "embedding"]) + print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() emb_df = 
emb_table.to_pandas() del emb_table + print(f"[timing] to_pandas: {time.time() - t0:.1f}s shape={emb_df.shape}", flush=True) + t0 = time.time() merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") + print(f"[timing] merge: {time.time() - t0:.1f}s shape={merged.shape}", flush=True) + del emb_df + dropped_no_embedding = len(tiles_df) - len(merged) return merged, dropped_no_embedding @@ -76,11 +89,14 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: + print(f"[{split_name}] downloading src tiles...", flush=True) + t0 = time.time() src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) df = pads.dataset(src_local, format="parquet").to_table().to_pandas() input_count = len(df) + print(f"[{split_name}] src tiles loaded: {input_count} rows {time.time() - t0:.1f}s", flush=True) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: @@ -115,6 +131,7 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) + print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) if dropped_no_embedding != 0: @@ -124,11 +141,16 @@ def process_split( flush=True, ) + t0 = time.time() merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) + t0 = time.time() merged.to_parquet(output_split_dir / "tiles.parquet", index=False) + print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) + print(f"[{split_name}] downloading slides.parquet...", flush=True) slides_local = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" ) From 
bfc9578a83747a4a07eeb82a1d20ccc67368e5cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 16:50:40 +0200 Subject: [PATCH 018/107] refactor: use pyarrow to avoid to pandas conversion --- preprocessing/embedding_dataset.py | 49 ++++++++++++++++++------------ 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 741ffdd0..771a2292 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -15,7 +15,10 @@ import mlflow import mlflow.artifacts import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc import pyarrow.dataset as pads +import pyarrow.parquet as pq from omegaconf import DictConfig, OmegaConf from rationai.mlkit import autolog, with_cli_args from rationai.mlkit.lightning.loggers import MLFlowLogger @@ -48,11 +51,15 @@ def apply_thresholds( def join_embeddings( - tiles_df: pd.DataFrame, + tiles_table: pa.Table, embedding_run_id: str, embedding_split: str, -) -> tuple[pd.DataFrame, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y).""" +) -> tuple[pa.Table, int]: + """Join filtered tile metadata with embeddings on (slide_id, x, y) using Arrow join. + + Stays entirely in Arrow to avoid the slow fixed-size-list to_pandas() conversion + on the embedding column. 
+ """ t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" @@ -66,17 +73,12 @@ def join_embeddings( print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) t0 = time.time() - emb_df = emb_table.to_pandas() + joined = tiles_table.join(emb_table, keys=["slide_id", "x", "y"], join_type="inner") del emb_table - print(f"[timing] to_pandas: {time.time() - t0:.1f}s shape={emb_df.shape}", flush=True) + print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) - t0 = time.time() - merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") - print(f"[timing] merge: {time.time() - t0:.1f}s shape={merged.shape}", flush=True) - del emb_df - - dropped_no_embedding = len(tiles_df) - len(merged) - return merged, dropped_no_embedding + dropped_no_embedding = tiles_table.num_rows - joined.num_rows + return joined, dropped_no_embedding def process_split( @@ -133,7 +135,11 @@ def process_split( df = df.drop(columns=drop_cols) print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) - merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) + tiles_table = pa.Table.from_pandas(df, preserve_index=False) + del df + + merged_table, dropped_no_embedding = join_embeddings(tiles_table, embedding_run_id, split_name) + del tiles_table if dropped_no_embedding != 0: print( f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " @@ -142,12 +148,13 @@ def process_split( ) t0 = time.time() - merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) + merged_table = merged_table.take(sort_indices) print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) t0 = time.time() - merged.to_parquet(output_split_dir / 
"tiles.parquet", index=False) + pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) print(f"[{split_name}] downloading slides.parquet...", flush=True) @@ -156,18 +163,22 @@ def process_split( ) shutil.copy(slides_local, output_split_dir / "slides.parquet") - log_label_distributions(split_name, merged) + log_label_distributions(split_name, merged_table) return { "input_count": input_count, "after_tissue_filter": after_tissue_filter, "after_class_threshold": after_class_threshold, - "after_join": len(merged), + "after_join": merged_table.num_rows, "dropped_no_embedding": dropped_no_embedding, } -def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: +def log_label_distributions(split_name: str, table: pa.Table) -> None: + has_fold = "fold" in table.schema.names + cols = ["label", "fold"] if has_fold else ["label"] + df = table.select(cols).to_pandas() + label_dist = ( df["label"].value_counts().rename_axis("label").reset_index(name="count") ) @@ -176,7 +187,7 @@ def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: artifact_file=f"fold_statistics/{split_name}_label_distribution.json", ) - if "fold" in df.columns: + if has_fold: fold_dist = ( df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() ) From eb213c6abc1312bd281b5bbb40a8abacf24c4d71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 20:23:21 +0200 Subject: [PATCH 019/107] fix: join on keys only --- preprocessing/embedding_dataset.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 771a2292..061949e6 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -72,9 +72,18 @@ def join_embeddings( emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", 
"embedding"]) print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + # Arrow Acero join doesn't support list in non-key fields, so join on + # keys only using a row-index column, then pull embeddings via take(). t0 = time.time() - joined = tiles_table.join(emb_table, keys=["slide_id", "x", "y"], join_type="inner") - del emb_table + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + + joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") + embeddings = emb_table.column("embedding").take(joined_keys.column("_emb_idx")) + del emb_table, emb_keys, emb_idx + + joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) + del joined_keys, embeddings print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From c92d9a1a5879ae1e8d61231a5b4184983eb4d633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 20:35:35 +0200 Subject: [PATCH 020/107] fix: typing --- preprocessing/embedding_dataset.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 061949e6..a193b856 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -74,13 +74,25 @@ def join_embeddings( # Arrow Acero join doesn't support list in non-key fields, so join on # keys only using a row-index column, then pull embeddings via take(). + # Cast embedding to large_list first: 1.1M rows * 768 doubles overflows int32 + # list offsets when chunks are concatenated by take(). 
t0 = time.time() + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + emb_col = emb_col.cast(pa.large_list(emb_col.type.value_type)) + elif pa.types.is_fixed_size_list(emb_col.type): + pass # fixed_size_list has no offsets, no overflow risk + else: + emb_col = emb_col.cast(pa.large_list(pa.float64())) + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table, emb_idx joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") - embeddings = emb_table.column("embedding").take(joined_keys.column("_emb_idx")) - del emb_table, emb_keys, emb_idx + del emb_keys + embeddings = emb_col.take(joined_keys.column("_emb_idx")) + del emb_col joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) del joined_keys, embeddings From 01cc39450e00b0eba3a9573e08d493498cd71172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 21:24:45 +0200 Subject: [PATCH 021/107] fix: add prints --- preprocessing/embedding_dataset.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index a193b856..4d868a83 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -74,29 +74,40 @@ def join_embeddings( # Arrow Acero join doesn't support list in non-key fields, so join on # keys only using a row-index column, then pull embeddings via take(). - # Cast embedding to large_list first: 1.1M rows * 768 doubles overflows int32 - # list offsets when chunks are concatenated by take(). 
- t0 = time.time() emb_col = emb_table.column("embedding") + print(f"[timing] embedding column type={emb_col.type}, num_chunks={emb_col.num_chunks}", flush=True) + + # Cast per chunk to large_list to avoid the int32 offset overflow that hits + # when take() concatenates chunks of list. Per-chunk casts touch + # the offset buffer only (each chunk individually fits int32). + t0 = time.time() if pa.types.is_list(emb_col.type): - emb_col = emb_col.cast(pa.large_list(emb_col.type.value_type)) - elif pa.types.is_fixed_size_list(emb_col.type): - pass # fixed_size_list has no offsets, no overflow risk - else: - emb_col = emb_col.cast(pa.large_list(pa.float64())) + target_type = pa.large_list(emb_col.type.value_type) + new_chunks = [c.cast(target_type) for c in emb_col.chunks] + emb_col = pa.chunked_array(new_chunks, type=target_type) + del new_chunks + print(f"[timing] cast embedding to large_list: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + print(f"[timing] build emb_keys: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") del emb_keys + print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) + + t0 = time.time() embeddings = emb_col.take(joined_keys.column("_emb_idx")) del emb_col + print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) del joined_keys, embeddings - print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) + print(f"[timing] assemble joined table: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) dropped_no_embedding = tiles_table.num_rows - joined.num_rows return joined, 
dropped_no_embedding From cad0d376e0d4eb132125ade9b7654be4c36288df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 21:35:58 +0200 Subject: [PATCH 022/107] refactor: use combine chunks --- preprocessing/embedding_dataset.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 4d868a83..73987ddc 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -100,8 +100,16 @@ def join_embeddings( print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) t0 = time.time() - embeddings = emb_col.take(joined_keys.column("_emb_idx")) + emb_array = emb_col.combine_chunks() del emb_col + print(f"[timing] combine_chunks: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() + indices = joined_keys.column("_emb_idx") + if isinstance(indices, pa.ChunkedArray): + indices = indices.combine_chunks() + embeddings = emb_array.take(indices) + del emb_array print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) t0 = time.time() From ae045526dd1f97096ff02a23740eab3fed44faf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:21:49 +0200 Subject: [PATCH 023/107] fix: lazy-cast embeddings to large_list and stay in Arrow during join Joining 1M+ rows of list embeddings was either OOMing on to_pandas() or hitting int32 list-offset overflow inside take(). The fix: - read embeddings into Arrow only and cast each chunk to large_list so take() concatenation uses int64 offsets; - run the join on keys plus a synthetic row index because Acero refuses list columns in non-key fields, then pull embeddings via take(); - combine_chunks() before take() for an O(N) single-pass copy; - write the parquet straight from Arrow, never materialising the embedding column in pandas. 
Also bumps the kube job memory to 64Gi to give the combined-chunks + take() peak some headroom, and trims the verbose [timing] prints down to one progress line per split. Co-Authored-By: Claude Sonnet 4.6 --- preprocessing/embedding_dataset.py | 83 +++++++++++------------------ scripts/submit_embedding_dataset.py | 2 +- 2 files changed, 31 insertions(+), 54 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 73987ddc..098759cc 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -55,68 +54,44 @@ def join_embeddings( embedding_run_id: str, embedding_split: str, ) -> tuple[pa.Table, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y) using Arrow join. + """Join filtered tile metadata with embeddings on (slide_id, x, y). - Stays entirely in Arrow to avoid the slow fixed-size-list to_pandas() conversion - on the embedding column. + Stays in Arrow throughout to avoid the very slow list -> pandas + conversion. Acero's join engine doesn't accept list columns in non-key + fields, so we join on keys plus a synthetic row index, then pull embeddings + via take(). The embedding column is cast per chunk to large_list to avoid + int32 offset overflow that bites take() when chunks are concatenated. 
""" - t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - print(f"[timing] download embeddings: {time.time() - t0:.1f}s", flush=True) - - t0 = time.time() - emb_ds = pads.dataset(emb_dir, format="parquet") - print(f"[timing] embedding dataset has {emb_ds.count_rows()} rows", flush=True) - emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", "embedding"]) - print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) - # Arrow Acero join doesn't support list in non-key fields, so join on - # keys only using a row-index column, then pull embeddings via take(). emb_col = emb_table.column("embedding") - print(f"[timing] embedding column type={emb_col.type}, num_chunks={emb_col.num_chunks}", flush=True) - - # Cast per chunk to large_list to avoid the int32 offset overflow that hits - # when take() concatenates chunks of list. Per-chunk casts touch - # the offset buffer only (each chunk individually fits int32). 
- t0 = time.time() if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) - new_chunks = [c.cast(target_type) for c in emb_col.chunks] - emb_col = pa.chunked_array(new_chunks, type=target_type) - del new_chunks - print(f"[timing] cast embedding to large_list: {time.time() - t0:.1f}s", flush=True) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) - t0 = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - print(f"[timing] build emb_keys: {time.time() - t0:.1f}s", flush=True) - t0 = time.time() - joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") + joined_keys = tiles_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) del emb_keys - print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) - - t0 = time.time() - emb_array = emb_col.combine_chunks() - del emb_col - print(f"[timing] combine_chunks: {time.time() - t0:.1f}s", flush=True) - t0 = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_array.take(indices) - del emb_array - print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) + embeddings = emb_col.combine_chunks().take(indices) + del emb_col - t0 = time.time() joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) - del joined_keys, embeddings - print(f"[timing] assemble joined table: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) - dropped_no_embedding = tiles_table.num_rows - joined.num_rows return joined, dropped_no_embedding @@ -131,14 +106,12 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: - print(f"[{split_name}] downloading src tiles...", flush=True) - t0 = time.time() 
+ print(f"[{split_name}] downloading source tiles", flush=True) src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) df = pads.dataset(src_local, format="parquet").to_table().to_pandas() input_count = len(df) - print(f"[{split_name}] src tiles loaded: {input_count} rows {time.time() - t0:.1f}s", flush=True) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: @@ -173,12 +146,18 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) - print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) + print( + f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) " + f"-> {after_class_threshold} (class threshold), joining embeddings", + flush=True, + ) tiles_table = pa.Table.from_pandas(df, preserve_index=False) del df - merged_table, dropped_no_embedding = join_embeddings(tiles_table, embedding_run_id, split_name) + merged_table, dropped_no_embedding = join_embeddings( + tiles_table, embedding_run_id, split_name + ) del tiles_table if dropped_no_embedding != 0: print( @@ -187,23 +166,21 @@ def process_split( flush=True, ) - t0 = time.time() - sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) + sort_indices = pc.sort_indices( + merged_table, sort_keys=[("slide_id", "ascending")] + ) merged_table = merged_table.take(sort_indices) - print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) - t0 = time.time() pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) - print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) - print(f"[{split_name}] downloading slides.parquet...", flush=True) slides_local = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" ) shutil.copy(slides_local, 
output_split_dir / "slides.parquet") log_label_distributions(split_name, merged_table) + print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True) return { "input_count": input_count, diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py index bbe4063f..23977df5 100644 --- a/scripts/submit_embedding_dataset.py +++ b/scripts/submit_embedding_dataset.py @@ -5,7 +5,7 @@ job_name="tissue-classification-embedding-dataset", username=..., cpu=8, - memory="32Gi", + memory="64Gi", gpu=None, public=False, script=[ From 82320db480999dec51cac0f3ae32f59501d83386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:23:03 +0200 Subject: [PATCH 024/107] fix: validate label/tissue_prop columns when derive=False Without this guard a malformed train artifact would crash deep inside apply_thresholds with a confusing KeyError. Surface a clear error that points at the expected upstream artifact instead. Co-Authored-By: Claude Sonnet 4.6 --- preprocessing/embedding_dataset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 098759cc..22242fc9 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -132,6 +132,15 @@ def process_split( lbl, tp = compute_label_and_tissue_prop(df, roi_cols) df["label"] = lbl df["tissue_prop"] = tp + else: + required = {"label", "tissue_prop"} + missing_required = required - set(df.columns) + if missing_required: + raise RuntimeError( + f"Source split '{split_name}' (derive=False) is missing required " + f"columns {sorted(missing_required)} in {src_artifact_path}. " + "Expected the kfold_split artifact, which writes label/tissue_prop/fold." 
+ ) df, after_tissue_filter = apply_thresholds( df, tissue_prop_min, thresholds, roi_cols From 3b0137f95bff66c83b4516c1e0fee89928cf028b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:35:46 +0200 Subject: [PATCH 025/107] chore: remove time --- preprocessing/embedding_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 22242fc9..5fd263b9 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -175,9 +175,7 @@ def process_split( flush=True, ) - sort_indices = pc.sort_indices( - merged_table, sort_keys=[("slide_id", "ascending")] - ) + sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) merged_table = merged_table.take(sort_indices) output_split_dir.mkdir(parents=True, exist_ok=True) From 8df47aae009fedc541dddbf4f28aa37e37838b1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 17:20:58 +0200 Subject: [PATCH 026/107] feat: add timing --- preprocessing/embedding_dataset.py | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 5fd263b9..9a7eea61 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -62,34 +63,67 @@ def join_embeddings( via take(). The embedding column is cast per chunk to large_list to avoid int32 offset overflow that bites take() when chunks are concatenated. 
""" + t = time.time() + print("[join] downloading embeddings", flush=True) emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) + print(f"[join] download: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] to_table", flush=True) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + print( + f"[join] to_table: {time.time() - t:.1f}s rows={emb_table.num_rows} " + f"chunks={emb_table.column('embedding').num_chunks}", + flush=True, + ) + t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) + print(f"[join] cast to large_list: {time.time() - t:.1f}s", flush=True) + t = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + print(f"[join] build keys table: {time.time() - t:.1f}s", flush=True) + t = time.time() + print("[join] arrow key-join", flush=True) joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys + print( + f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", + flush=True, + ) + t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_col.combine_chunks().take(indices) + print(f"[join] combine indices: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] combine_chunks(embeddings)", flush=True) + emb_contig = emb_col.combine_chunks() del emb_col + print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] take(embeddings)", flush=True) + embeddings = emb_contig.take(indices) + del emb_contig 
+ print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From 926753d54e072d70b121b70c13cc840809d20527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 17:48:01 +0200 Subject: [PATCH 027/107] chore: revert to the previous state --- preprocessing/embedding_dataset.py | 36 +----------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 9a7eea61..5fd263b9 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -63,67 +62,34 @@ def join_embeddings( via take(). The embedding column is cast per chunk to large_list to avoid int32 offset overflow that bites take() when chunks are concatenated. 
""" - t = time.time() - print("[join] downloading embeddings", flush=True) emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - print(f"[join] download: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] to_table", flush=True) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) - print( - f"[join] to_table: {time.time() - t:.1f}s rows={emb_table.num_rows} " - f"chunks={emb_table.column('embedding').num_chunks}", - flush=True, - ) - t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) - print(f"[join] cast to large_list: {time.time() - t:.1f}s", flush=True) - t = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - print(f"[join] build keys table: {time.time() - t:.1f}s", flush=True) - t = time.time() - print("[join] arrow key-join", flush=True) joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys - print( - f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", - flush=True, - ) - t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - print(f"[join] combine indices: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] combine_chunks(embeddings)", flush=True) - emb_contig = emb_col.combine_chunks() + embeddings = emb_col.combine_chunks().take(indices) del emb_col - print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] take(embeddings)", flush=True) - embeddings = emb_contig.take(indices) - del emb_contig 
- print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From b0e9ba4290f3078a2f173168f0bb76b3970f3acf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 20:21:51 +0200 Subject: [PATCH 028/107] feat: add prints --- preprocessing/embedding_dataset.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 5fd263b9..c2010208 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -80,16 +81,29 @@ def join_embeddings( emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + t = time.time() joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys + print( + f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", + flush=True, + ) indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_col.combine_chunks().take(indices) + + t = time.time() + emb_contig = emb_col.combine_chunks() del emb_col + print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) + + t = time.time() + embeddings = emb_contig.take(indices) + del emb_contig + print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From 6a915de65f4c8cf62208dcf04cf02ca5d742f1fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 10:03:36 +0200 Subject: [PATCH 029/107] refactor: use 
discusssed thresholds --- .../experiment/preprocessing/embedding_dataset.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml index 71f3e687..8004e2e4 100644 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -8,11 +8,11 @@ tissue_prop_min: 0.2 thresholds: Nerve: 0.0 Blood: 0.0 - Connective-Tissue: 0.0 - Fat: 0.0 - Epithelium: 0.0 - Muscle: 0.0 - Other: 0.0 + Connective-Tissue: 0.4 + Fat: 0.5 + Epithelium: 0.2 + Muscle: 0.4 + Other: 0.5 metadata: run_name: Embedding dataset ${dataset.name} From 0f50307daba3d5e7433b9bf7dddceda14a05f785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 10:15:05 +0200 Subject: [PATCH 030/107] refactor: use different labeling strategy --- preprocessing/embedding_dataset.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index c2010208..a1e6545a 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -14,6 +14,7 @@ import hydra import mlflow import mlflow.artifacts +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc @@ -32,22 +33,31 @@ def apply_thresholds( thresholds: dict[str, float], roi_cols: list[str], ) -> tuple[pd.DataFrame, int]: - """Filter df by tissue_prop_min then by per-dominant-class roi threshold. + """Filter df by tissue_prop_min, then keep tiles where ANY class meets its + threshold; among passing classes, the highest-coverage one becomes the label. Returns ``(filtered_df, after_tissue_count)`` so the caller can log both - intermediate counts. + intermediate counts. 
The returned df has its ``label`` column rewritten to + reflect the argmax-over-passers rule. """ df = df[df["tissue_prop"] >= tissue_prop_min] after_tissue = len(df) if df.empty: return df, after_tissue - roi_only = df[roi_cols] - dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") - dominant_value = roi_only.max(axis=1).to_numpy() - threshold_per_row = dominant.map(thresholds).to_numpy() - keep = dominant_value >= threshold_per_row - return df[keep].copy(), after_tissue + class_names = np.array([c.removeprefix("roi_coverage_") for c in roi_cols]) + thr = np.array([thresholds[c] for c in class_names], dtype=float) + roi = df[roi_cols].to_numpy() + passes = roi >= thr + keep = passes.any(axis=1) + + masked = np.where(passes, roi, -np.inf) + label_idx = masked.argmax(axis=1) + new_labels = class_names[label_idx] + + out = df[keep].copy() + out["label"] = new_labels[keep] + return out, after_tissue def join_embeddings( From 4d953dca0d0a619dd10291168c357725b96fed6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 12:41:07 +0200 Subject: [PATCH 031/107] feat: implement training pipeline --- configs/data/dataset.yaml | 1 + configs/experiment/ml/linear_classifier.yaml | 28 +++ configs/ml/data/embedding.yaml | 24 ++ configs/ml/model/linear_classifier.yaml | 18 ++ configs/ml/trainer/default.yaml | 29 +++ ml/__init__.py | 0 ml/__main__.py | 31 +++ ml/callbacks/__init__.py | 4 + ml/callbacks/parquet_prediction_writer.py | 77 +++++++ ml/data/__init__.py | 4 + ml/data/data_module.py | 64 ++++++ ml/data/datasets/__init__.py | 4 + ml/data/datasets/embedding_tiles.py | 78 +++++++ ml/meta_arch.py | 220 +++++++++++++++++++ ml/modeling/__init__.py | 0 ml/typing.py | 6 + preprocessing/embedding_dataset.py | 11 +- pyproject.toml | 2 + scripts/submit_train_linear.py | 18 ++ uv.lock | 4 + 20 files changed, 619 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier.yaml 
create mode 100644 configs/ml/data/embedding.yaml create mode 100644 configs/ml/model/linear_classifier.yaml create mode 100644 configs/ml/trainer/default.yaml create mode 100644 ml/__init__.py create mode 100644 ml/__main__.py create mode 100644 ml/callbacks/__init__.py create mode 100644 ml/callbacks/parquet_prediction_writer.py create mode 100644 ml/data/__init__.py create mode 100644 ml/data/data_module.py create mode 100644 ml/data/datasets/__init__.py create mode 100644 ml/data/datasets/embedding_tiles.py create mode 100644 ml/meta_arch.py create mode 100644 ml/modeling/__init__.py create mode 100644 ml/typing.py create mode 100644 scripts/submit_train_linear.py diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0cf33e25..57e19255 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -16,6 +16,7 @@ dataset: filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" + embedding_dataset_run_id: "3ab86e376d38481dbac5bc352f7ac7c9" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/experiment/ml/linear_classifier.yaml new file mode 100644 index 00000000..6796489b --- /dev/null +++ b/configs/experiment/ml/linear_classifier.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/trainer: default + - /ml/data: embedding + - /ml/model: linear_classifier + - _self_ + +mode: fit + +embedding_dataset_run_id: ${dataset.mlflow_artifacts.embedding_dataset_run_id} +train_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/train/tiles.parquet +test_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/test/tiles.parquet +val_fold: 0 + +mlflow_artifact_path: linear_classifier + +metadata: + run_name: Linear 
Classifier ${dataset.name} fold=${val_fold} + description: "Linear probe over frozen Virchow2 embeddings produced by embedding_dataset run ${embedding_dataset_run_id}." + hyperparams: + embedding_dataset_run_id: ${embedding_dataset_run_id} + val_fold: ${val_fold} + learning_rate: ${model.learning_rate} + weight_decay: ${model.weight_decay} + batch_size: ${data.batch_size} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml new file mode 100644 index 00000000..d80ca0c5 --- /dev/null +++ b/configs/ml/data/embedding.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +data: + batch_size: 1024 + num_workers: 4 + + train: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${train_tiles_uri} + class_indices: ${class_indices} + exclude_folds: + - ${val_fold} + + val: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${train_tiles_uri} + class_indices: ${class_indices} + include_folds: + - ${val_fold} + + test: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${test_tiles_uri} + class_indices: ${class_indices} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml new file mode 100644 index 00000000..dfff43ca --- /dev/null +++ b/configs/ml/model/linear_classifier.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +model: + backbone: + _target_: torch.nn.Identity + + decode_head: + _target_: torch.nn.Linear + in_features: 2560 + out_features: ${len:${class_indices}} + + criterion: + _target_: torch.nn.CrossEntropyLoss + + class_indices: ${class_indices} + + learning_rate: 1.0e-3 + weight_decay: 0.0 diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml new file mode 100644 index 00000000..cf5766dc --- /dev/null +++ b/configs/ml/trainer/default.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +trainer: + max_epochs: 50 + accelerator: auto + devices: auto + precision: 32 + log_every_n_steps: 50 + deterministic: false + + callbacks: + early_stopping: + 
_target_: lightning.pytorch.callbacks.EarlyStopping + monitor: validation/loss + mode: min + patience: 5 + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + monitor: validation/loss + mode: min + save_top_k: 1 + filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + auto_insert_metric_name: false + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + prediction_writer: + _target_: ml.callbacks.ParquetPredictionWriter + output_filename: predictions.parquet diff --git a/ml/__init__.py b/ml/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/__main__.py b/ml/__main__.py new file mode 100644 index 00000000..61ef4008 --- /dev/null +++ b/ml/__main__.py @@ -0,0 +1,31 @@ +from random import randint + +import hydra +from lightning import seed_everything +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import Trainer, autolog +from rationai.mlkit.lightning.loggers import MLFlowLogger + +from ml.data import DataModule +from ml.meta_arch import MetaArch + + +OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True +) +OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + seed_everything(config.seed, workers=True) + + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + model = hydra.utils.instantiate(config.model, _target_=MetaArch) + trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + + +if __name__ == "__main__": + main() diff --git a/ml/callbacks/__init__.py b/ml/callbacks/__init__.py new file mode 100644 index 00000000..e9c20c4c --- /dev/null +++ b/ml/callbacks/__init__.py @@ -0,0 +1,4 @@ +from 
ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter + + +__all__ = ["ParquetPredictionWriter"] diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py new file mode 100644 index 00000000..d3b91f77 --- /dev/null +++ b/ml/callbacks/parquet_prediction_writer.py @@ -0,0 +1,77 @@ +"""Aggregate ``predict_step`` outputs and write them as a parquet artifact.""" + +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import lightning as pl +import mlflow +import numpy as np +import pandas as pd +from lightning.pytorch.callbacks import BasePredictionWriter + + +class ParquetPredictionWriter(BasePredictionWriter): + """Collect per-tile predictions and write them as a parquet artifact. + + Aggregates ``predict_step`` outputs across the predict loop, writes one + parquet file with ``slide_id``, ``target``, ``pred``, ``prob_`` + columns and logs it to the active MLflow run. + """ + def __init__(self, output_filename: str = "predictions.parquet") -> None: + super().__init__(write_interval="epoch") + self.output_filename = output_filename + self._batches: list[dict[str, Any]] = [] + + def write_on_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + prediction: dict[str, Any], + batch_indices: Sequence[int] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self._batches.append(prediction) + + def write_on_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Any, + batch_indices: Any, + ) -> None: + if not self._batches: + return + + slide_ids: list[str] = [] + targets: list[int] = [] + preds: list[int] = [] + probs: list[np.ndarray] = [] + for b in self._batches: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) + prob_matrix = np.concatenate(probs, axis=0) + + class_names = getattr(pl_module, 
"class_names", None) + prob_columns = ( + [f"prob_{c}" for c in class_names] + if class_names is not None and len(class_names) == prob_matrix.shape[1] + else [f"prob_{i}" for i in range(prob_matrix.shape[1])] + ) + + df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds}) + df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1) + + out_path = Path(trainer.default_root_dir) / self.output_filename + out_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(out_path, index=False) + + active = mlflow.active_run() + if active is not None: + mlflow.log_artifact(str(out_path), artifact_path="predictions") + + self._batches.clear() diff --git a/ml/data/__init__.py b/ml/data/__init__.py new file mode 100644 index 00000000..e7058ee5 --- /dev/null +++ b/ml/data/__init__.py @@ -0,0 +1,4 @@ +from ml.data.data_module import DataModule + + +__all__ = ["DataModule"] diff --git a/ml/data/data_module.py b/ml/data/data_module.py new file mode 100644 index 00000000..c39b9503 --- /dev/null +++ b/ml/data/data_module.py @@ -0,0 +1,64 @@ +from collections.abc import Iterable + +from hydra.utils import instantiate +from lightning import LightningDataModule +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from ml.typing import Input + + +class DataModule(LightningDataModule): + """Generic Lightning datamodule that instantiates datasets lazily per stage. + + Mirrors the template pattern: ``**datasets`` accepts ``train``, ``val``, + ``test`` (or ``predict``) DictConfigs whose targets resolve to ``Dataset``s. 
+ """ + + def __init__( + self, batch_size: int, num_workers: int = 0, **datasets: DictConfig + ) -> None: + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + self.datasets = datasets + + def setup(self, stage: str) -> None: + match stage: + case "fit": + self.train = instantiate(self.datasets["train"]) + self.val = instantiate(self.datasets["val"]) + case "validate": + self.val = instantiate(self.datasets["val"]) + case "test": + self.test = instantiate(self.datasets["test"]) + case "predict": + self.predict = instantiate(self.datasets["predict"]) + + def train_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.train, + batch_size=self.batch_size, + shuffle=True, + drop_last=True, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.val, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.test, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def predict_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.predict, batch_size=self.batch_size, num_workers=self.num_workers + ) diff --git a/ml/data/datasets/__init__.py b/ml/data/datasets/__init__.py new file mode 100644 index 00000000..cd2f91a9 --- /dev/null +++ b/ml/data/datasets/__init__.py @@ -0,0 +1,4 @@ +from ml.data.datasets.embedding_tiles import EmbeddingTilesDataset + + +__all__ = ["EmbeddingTilesDataset"] diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py new file mode 100644 index 00000000..ad4d1041 --- /dev/null +++ b/ml/data/datasets/embedding_tiles.py @@ -0,0 +1,78 @@ +"""Tile-embedding dataset. + +Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. 
+""" + + +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from mlflow.artifacts import download_artifacts +from torch.utils.data import Dataset + +from ml.typing import Sample + + +class EmbeddingTilesDataset(Dataset[Sample]): + """Returns ``(embedding, class_index, slide_id)`` triples from a tiles parquet. + + A single dataset instance corresponds to one parquet (train or test); + fold-based CV is expressed by ``include_folds`` / ``exclude_folds`` + filters applied to the train parquet via separate dataset configs. + """ + + REQUIRED_COLUMNS = ("embedding", "label", "slide_id") + + def __init__( + self, + path_or_uri: str | Path, + class_indices: dict[str, int], + include_folds: list[int] | None = None, + exclude_folds: list[int] | None = None, + ) -> None: + df = self._load_parquet(path_or_uri) + + missing = set(self.REQUIRED_COLUMNS) - set(df.columns) + if missing: + raise ValueError(f"tiles parquet missing columns: {sorted(missing)}") + + if include_folds is not None or exclude_folds is not None: + if "fold" not in df.columns: + raise RuntimeError( + "fold filter requested but 'fold' column not in parquet" + ) + if include_folds is not None: + df = df[df["fold"].isin(include_folds)] + if exclude_folds is not None: + df = df[~df["fold"].isin(exclude_folds)] + + unknown = set(df["label"].unique()) - set(class_indices.keys()) + if unknown: + raise ValueError( + f"labels in tiles not present in class_indices: {sorted(unknown)}" + ) + + self.embeddings = np.stack(df["embedding"].tolist()).astype(np.float32) + self.labels = df["label"].map(class_indices).to_numpy(dtype=np.int64) + self.slide_ids = df["slide_id"].to_numpy() + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + ) + + @staticmethod + def _load_parquet(path_or_uri: str | Path) -> pd.DataFrame: + s = str(path_or_uri) 
+ if s.startswith(("mlflow-artifacts:/", "runs:/")): + local = download_artifacts(artifact_uri=s) + else: + local = s + return pd.read_parquet(local) diff --git a/ml/meta_arch.py b/ml/meta_arch.py new file mode 100644 index 00000000..602693a1 --- /dev/null +++ b/ml/meta_arch.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from collections.abc import Iterable +from typing import Any + +import mlflow +import numpy as np +import pandas as pd +import torch +from lightning import LightningModule +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor, nn +from torch.optim.optimizer import Optimizer +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassConfusionMatrix, + MulticlassF1Score, +) + +from ml.typing import Input, Outputs + + +class MetaArch(LightningModule): + """Top-level classification architecture: backbone + decode_head + criterion. + + For linear probing on precomputed embeddings, ``backbone`` is typically + ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. + """ + + # TODO: support class_weights for CE loss when class distribution is heavily + # imbalanced. Inject either via config (list[float]) or compute from the + # train fold label distribution at setup(). 
+ + def __init__( + self, + backbone: nn.Module, + decode_head: nn.Module, + criterion: nn.Module, + class_indices: dict[str, int], + learning_rate: float = 1e-3, + weight_decay: float = 0.0, + ) -> None: + super().__init__() + self.save_hyperparameters(ignore=["backbone", "decode_head", "criterion"]) + + self.backbone = backbone + self.decode_head = decode_head + self.criterion = criterion + + self.class_names = [ + n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) + ] + num_classes = len(self.class_names) + + macro_metrics = MetricCollection( + { + "acc_macro": MulticlassAccuracy( + num_classes=num_classes, average="macro" + ), + "f1_macro": MulticlassF1Score(num_classes=num_classes, average="macro"), + } + ) + per_class_metrics = MetricCollection( + { + "acc_per_class": MulticlassAccuracy( + num_classes=num_classes, average=None + ), + "f1_per_class": MulticlassF1Score( + num_classes=num_classes, average=None + ), + } + ) + self.val_metrics = macro_metrics.clone(prefix="validation/") + self.test_metrics = macro_metrics.clone(prefix="test/") + self.val_per_class = per_class_metrics.clone(prefix="validation/") + self.test_per_class = per_class_metrics.clone(prefix="test/") + self.val_confmat = MulticlassConfusionMatrix(num_classes=num_classes) + self.test_confmat = MulticlassConfusionMatrix(num_classes=num_classes) + + self._test_slide_correct: dict[str, int] = defaultdict(int) + self._test_slide_total: dict[str, int] = defaultdict(int) + + def forward(self, x: Tensor) -> Outputs: + features = self.backbone(x) + return self.decode_head(features) + + def training_step(self, batch: Input, batch_idx: int) -> Tensor: + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, targets) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def validation_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, 
targets) + self.log("validation/loss", loss, on_epoch=True, prog_bar=True) + self.val_metrics.update(outputs, targets) + self.val_per_class.update(outputs, targets) + self.val_confmat.update(outputs, targets) + self.log_dict(self.val_metrics, on_epoch=True) + + def on_validation_epoch_end(self) -> None: + self._log_per_class(self.val_per_class, "validation") + self._log_confmat(self.val_confmat, "validation") + + def test_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, slide_ids = batch + outputs = self(inputs) + self.test_metrics.update(outputs, targets) + self.test_per_class.update(outputs, targets) + self.test_confmat.update(outputs, targets) + self.log_dict(self.test_metrics, on_epoch=True) + + preds = outputs.argmax(dim=1) + correct = (preds == targets).cpu().tolist() + for slide_id, ok in zip(slide_ids, correct, strict=True): + self._test_slide_correct[slide_id] += int(ok) + self._test_slide_total[slide_id] += 1 + + def on_test_epoch_end(self) -> None: + self._log_per_class(self.test_per_class, "test") + self._log_confmat(self.test_confmat, "test") + self._log_per_slide_accuracy() + self._test_slide_correct.clear() + self._test_slide_total.clear() + + def predict_step( + self, batch: Input, batch_idx: int, dataloader_idx: int = 0 + ) -> dict[str, Any]: + inputs, targets, slide_ids = batch + outputs = self(inputs) + probs = outputs.softmax(dim=1) + preds = outputs.argmax(dim=1) + return { + "slide_id": list(slide_ids), + "target": targets.cpu(), + "pred": preds.cpu(), + "probs": probs.cpu(), + } + + def configure_optimizers(self) -> Optimizer: + return torch.optim.AdamW( + self.parameters(), + lr=self.hparams["learning_rate"], + weight_decay=self.hparams["weight_decay"], + ) + + def _log_per_class(self, collection: MetricCollection, split: str) -> None: + computed = collection.compute() + for metric_name, values in computed.items(): + tag = metric_name.split("/")[-1] # e.g. 
"acc_per_class" + for cls_name, val in zip(self.class_names, values.tolist(), strict=True): + self.log(f"{split}/{tag}/{cls_name}", val, on_epoch=True) + collection.reset() + + def _log_confmat(self, confmat: MulticlassConfusionMatrix, split: str) -> None: + matrix = confmat.compute().cpu().numpy() + confmat.reset() + fig = _confmat_figure(matrix, self.class_names, title=f"{split} confmat") + artifact_file = f"confusion_matrix/{split}_epoch_{self.current_epoch}.png" + try: + mlflow.log_figure(fig, artifact_file=artifact_file) + finally: + plt.close(fig) + + def _log_per_slide_accuracy(self) -> None: + accs = [ + self._test_slide_correct[s] / self._test_slide_total[s] + for s in self._test_slide_total + ] + if not accs: + return + self.log("test/slide_acc_mean", float(np.mean(accs)), on_epoch=True) + self.log("test/slide_acc_median", float(np.median(accs)), on_epoch=True) + self.log("test/slide_acc_min", float(np.min(accs)), on_epoch=True) + + rows = [ + { + "slide_id": s, + "tile_accuracy": self._test_slide_correct[s] / n, + "n_tiles": n, + } + for s, n in self._test_slide_total.items() + ] + mlflow.log_table( + data=pd.DataFrame(rows), + artifact_file="per_slide/test_tile_accuracy.json", + ) + + +def _confmat_figure( + matrix: np.ndarray, class_names: Iterable[str], title: str +) -> Figure: + fig, ax = plt.subplots(figsize=(6, 5)) + im = ax.imshow(matrix, cmap="Blues") + ax.set_title(title) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + names = list(class_names) + ax.set_xticks(range(len(names))) + ax.set_yticks(range(len(names))) + ax.set_xticklabels(names, rotation=45, ha="right") + ax.set_yticklabels(names) + for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + ax.text( + j, + i, + str(matrix[i, j]), + ha="center", + va="center", + color="white" if matrix[i, j] > matrix.max() / 2 else "black", + fontsize=8, + ) + fig.colorbar(im, ax=ax) + fig.tight_layout() + return fig diff --git a/ml/modeling/__init__.py b/ml/modeling/__init__.py new 
file mode 100644 index 00000000..e69de29b diff --git a/ml/typing.py b/ml/typing.py new file mode 100644 index 00000000..7060f5e4 --- /dev/null +++ b/ml/typing.py @@ -0,0 +1,6 @@ +import torch + + +type Sample = tuple[torch.Tensor, int, str] +type Input = tuple[torch.Tensor, torch.Tensor, list[str]] +type Outputs = torch.Tensor diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index a1e6545a..7669db9b 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -33,14 +33,17 @@ def apply_thresholds( thresholds: dict[str, float], roi_cols: list[str], ) -> tuple[pd.DataFrame, int]: - """Filter df by tissue_prop_min, then keep tiles where ANY class meets its - threshold; among passing classes, the highest-coverage one becomes the label. + """Filter tiles by tissue + per-class thresholds and rewrite labels. + + Filters ``df`` by ``tissue_prop_min``, then keeps tiles where ANY class + meets its threshold; among passing classes, the highest-coverage one + becomes the label. Returns ``(filtered_df, after_tissue_count)`` so the caller can log both intermediate counts. The returned df has its ``label`` column rewritten to reflect the argmax-over-passers rule. 
""" - df = df[df["tissue_prop"] >= tissue_prop_min] + df = df.loc[df["tissue_prop"] >= tissue_prop_min] after_tissue = len(df) if df.empty: return df, after_tissue @@ -55,7 +58,7 @@ def apply_thresholds( label_idx = masked.argmax(axis=1) new_labels = class_names[label_idx] - out = df[keep].copy() + out = df.loc[pd.Series(keep, index=df.index)].copy() out["label"] = new_labels[keep] return out, after_tissue diff --git a/pyproject.toml b/pyproject.toml index 183bdd91..bc54e933 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "tifffile>=2025.12.20", "torch>=2.0.0", "torchvision>=0.15.0", + "lightning>=2.0.0", + "torchmetrics>=1.0.0", "timm>=1.0.0", "einops>=0.8.0", "matplotlib>=3.10.7", diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py new file mode 100644 index 00000000..11fc7865 --- /dev/null +++ b/scripts/submit_train_linear.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear", + username=..., + cpu=8, + memory="64Gi", + gpu="A40", + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=... 
val_fold=0,1,2,3,4 --mutlirun", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/uv.lock b/uv.lock index 30b783b9..d884e875 100644 --- a/uv.lock +++ b/uv.lock @@ -2292,6 +2292,7 @@ dependencies = [ { name = "datasets" }, { name = "einops" }, { name = "hydra-core" }, + { name = "lightning" }, { name = "matplotlib" }, { name = "mlflow" }, { name = "numpy" }, @@ -2309,6 +2310,7 @@ dependencies = [ { name = "tifffile" }, { name = "timm" }, { name = "torch" }, + { name = "torchmetrics" }, { name = "torchvision" }, { name = "tqdm" }, ] @@ -2327,6 +2329,7 @@ requires-dist = [ { name = "datasets", specifier = ">=4.0.0" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, + { name = "lightning", specifier = ">=2.0.0" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "mlflow", specifier = "<3.0.0" }, { name = "numpy", specifier = ">=2.3.5" }, @@ -2345,6 +2348,7 @@ requires-dist = [ { name = "tifffile", specifier = ">=2025.12.20" }, { name = "timm", specifier = ">=1.0.0" }, { name = "torch", specifier = ">=2.0.0" }, + { name = "torchmetrics", specifier = ">=1.0.0" }, { name = "torchvision", specifier = ">=0.15.0" }, { name = "tqdm", specifier = ">=4.66.0" }, ] From d5798bc3cb9ef1c19d167039496cccd594539252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 18:26:58 +0200 Subject: [PATCH 032/107] feat: add class weights --- configs/data/dataset.yaml | 2 +- ml/callbacks/parquet_prediction_writer.py | 1 + ml/data/datasets/embedding_tiles.py | 1 - ml/meta_arch.py | 12 ++++++++++++ scripts/submit_train_linear.py | 8 ++++---- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 57e19255..7419d6cc 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -16,7 +16,7 @@ dataset: filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: 
"850c81506684450b9af92296acfd045a" embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" - embedding_dataset_run_id: "3ab86e376d38481dbac5bc352f7ac7c9" + embedding_dataset_run_id: "b4a937ef6b334533807f08a191083401" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index d3b91f77..1fe85580 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -18,6 +18,7 @@ class ParquetPredictionWriter(BasePredictionWriter): parquet file with ``slide_id``, ``target``, ``pred``, ``prob_`` columns and logs it to the active MLflow run. """ + def __init__(self, output_filename: str = "predictions.parquet") -> None: super().__init__(write_interval="epoch") self.output_filename = output_filename diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index ad4d1041..e60bc8f6 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -3,7 +3,6 @@ Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. 
""" - from pathlib import Path import numpy as np diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 602693a1..1ee25753 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -81,6 +81,18 @@ def __init__( self._test_slide_correct: dict[str, int] = defaultdict(int) self._test_slide_total: dict[str, int] = defaultdict(int) + def setup(self, stage: str) -> None: + if stage == "fit": + labels = self.trainer.datamodule.train.labels + num_classes = len(self.class_names) + counts = np.bincount(labels, minlength=num_classes).astype(float) + weights = len(labels) / (num_classes * counts) + self.criterion = nn.CrossEntropyLoss( + weight=torch.tensor(weights, dtype=torch.float32) + ) + for cls, w in zip(self.class_names, weights.tolist(), strict=True): + mlflow.log_metric(f"class_weight/{cls}", w) + def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) return self.decode_head(features) diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py index 11fc7865..d19f6747 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username=..., + username="vcifka", cpu=8, memory="64Gi", - gpu="A40", + gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... 
val_fold=0,1,2,3,4 --mutlirun", + "uv run python -m ml +experiment=ml/linear_classifier val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) From ae45cd54244702ad392e6aa15265da6bdbacd5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:18:05 +0200 Subject: [PATCH 033/107] refactor: join embeddings with metadata while loading the dataset --- configs/data/dataset.yaml | 3 +- configs/experiment/ml/linear_classifier.yaml | 30 +++- configs/ml/data/embedding.yaml | 15 +- ml/data/datasets/embedding_tiles.py | 159 +++++++++++++++---- ml/meta_arch.py | 6 +- pyproject.toml | 6 +- 6 files changed, 171 insertions(+), 48 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 7419d6cc..172e48fb 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,8 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" - embedding_dataset_run_id: "b4a937ef6b334533807f08a191083401" + embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/experiment/ml/linear_classifier.yaml index 6796489b..2e9a9525 100644 --- a/configs/experiment/ml/linear_classifier.yaml +++ b/configs/experiment/ml/linear_classifier.yaml @@ -10,19 +10,39 @@ defaults: mode: fit -embedding_dataset_run_id: ${dataset.mlflow_artifacts.embedding_dataset_run_id} -train_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/train/tiles.parquet -test_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/test/tiles.parquet +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_run_id: 
${dataset.mlflow_artifacts.kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + val_fold: 0 +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + mlflow_artifact_path: linear_classifier metadata: run_name: Linear Classifier ${dataset.name} fold=${val_fold} - description: "Linear probe over frozen Virchow2 embeddings produced by embedding_dataset run ${embedding_dataset_run_id}." + description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), kfold metadata ${kfold_run_id}." hyperparams: - embedding_dataset_run_id: ${embedding_dataset_run_id} + embedding_run_id: ${embedding_run_id} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} batch_size: ${data.batch_size} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml index d80ca0c5..597e012e 100644 --- a/configs/ml/data/embedding.yaml +++ b/configs/ml/data/embedding.yaml @@ -6,19 +6,28 @@ data: train: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${train_tiles_uri} + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} exclude_folds: - ${val_fold} val: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${train_tiles_uri} + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} class_indices: ${class_indices} + thresholds: 
${thresholds} + tissue_prop_min: ${tissue_prop_min} include_folds: - ${val_fold} test: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${test_tiles_uri} + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index e60bc8f6..6cf2c1df 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -1,12 +1,16 @@ """Tile-embedding dataset. -Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. +Joins precomputed tile embeddings with tile metadata (k-fold parquet for train, +filter_tiles parquet for test) and applies tissue + per-class thresholds at +load time to produce ``(embedding, class_index, slide_id)`` triples. """ from pathlib import Path import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.dataset as pads import torch from mlflow.artifacts import download_artifacts from torch.utils.data import Dataset @@ -15,47 +19,81 @@ class EmbeddingTilesDataset(Dataset[Sample]): - """Returns ``(embedding, class_index, slide_id)`` triples from a tiles parquet. + """Tile-level embedding dataset with on-the-fly filtering and labeling. - A single dataset instance corresponds to one parquet (train or test); - fold-based CV is expressed by ``include_folds`` / ``exclude_folds`` - filters applied to the train parquet via separate dataset configs. - """ + Inner-joins ``embedding`` parquet with ``metadata`` parquet on + ``(slide_id, x, y)``. Metadata must contain ``roi_coverage_*`` columns; + label is the dominant class whose coverage meets its threshold. Tiles + failing the tissue proportion floor, with more than one annotated class, + or whose dominant class falls below its threshold are dropped. 
- REQUIRED_COLUMNS = ("embedding", "label", "slide_id") + For train/val: pass the k-fold parquet as ``metadata_uri`` and use + ``include_folds`` / ``exclude_folds`` to split. For test: pass the + filter_tiles parquet (no fold column). + """ def __init__( self, - path_or_uri: str | Path, + embedding_uri: str | Path, + metadata_uri: str | Path, class_indices: dict[str, int], + thresholds: dict[str, float], + tissue_prop_min: float, include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: - df = self._load_parquet(path_or_uri) + meta_df = self._filter_metadata( + metadata_uri, + thresholds, + tissue_prop_min, + include_folds, + exclude_folds, + ) - missing = set(self.REQUIRED_COLUMNS) - set(df.columns) - if missing: - raise ValueError(f"tiles parquet missing columns: {sorted(missing)}") + emb_dir = self._resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) - if include_folds is not None or exclude_folds is not None: - if "fold" not in df.columns: - raise RuntimeError( - "fold filter requested but 'fold' column not in parquet" - ) - if include_folds is not None: - df = df[df["fold"].isin(include_folds)] - if exclude_folds is not None: - df = df[~df["fold"].isin(exclude_folds)] + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table - unknown = set(df["label"].unique()) - set(class_indices.keys()) + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if 
joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + + indices = joined_keys.column("_emb_idx") + if isinstance(indices, pa.ChunkedArray): + indices = indices.combine_chunks() + embeddings_arrow = emb_col.combine_chunks().take(indices) + + embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) + self.embeddings = ( + embeddings_arrow.values.to_numpy(zero_copy_only=False) + .astype(np.float32) + .reshape(len(embeddings_arrow), embedding_dim) + ) + labels = joined_keys.column("label").to_pandas() + unknown = set(labels.unique()) - set(class_indices.keys()) if unknown: raise ValueError( - f"labels in tiles not present in class_indices: {sorted(unknown)}" + f"labels in data not present in class_indices: {sorted(unknown)}" ) - - self.embeddings = np.stack(df["embedding"].tolist()).astype(np.float32) - self.labels = df["label"].map(class_indices).to_numpy(dtype=np.int64) - self.slide_ids = df["slide_id"].to_numpy() + self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() def __len__(self) -> int: return len(self.labels) @@ -68,10 +106,67 @@ def __getitem__(self, idx: int) -> Sample: ) @staticmethod - def _load_parquet(path_or_uri: str | Path) -> pd.DataFrame: + def _filter_metadata( + metadata_uri: str | Path, + thresholds: dict[str, float], + tissue_prop_min: float, + include_folds: list[int] | None, + exclude_folds: list[int] | None, + ) -> pd.DataFrame: + local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + df = pd.read_parquet(local) + + roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] + if not roi_cols: + raise ValueError( + "metadata parquet has no roi_coverage_* columns; cannot label" + ) + + classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} + missing_thresholds = classes_in_data - set(thresholds.keys()) + if missing_thresholds: + raise ValueError( + f"thresholds missing 
entries for classes present in data: " + f"{sorted(missing_thresholds)}" + ) + + tissue_prop = df[roi_cols].sum(axis=1).to_numpy() + df = df.loc[tissue_prop >= tissue_prop_min] + if df.empty: + raise RuntimeError("all tiles dropped by tissue_prop_min filter") + + nonzero_classes = (df[roi_cols].to_numpy() > 0).sum(axis=1) + df = df.loc[pd.Series(nonzero_classes <= 1, index=df.index)] + if df.empty: + raise RuntimeError("all tiles dropped by single-class filter") + + roi_only = df[roi_cols] + dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") + dominant_value = roi_only.max(axis=1).to_numpy() + threshold_per_row = dominant.map(thresholds).to_numpy() + keep = dominant_value >= threshold_per_row + df = df.loc[pd.Series(keep, index=df.index)].copy() + df["label"] = dominant.to_numpy()[keep] + if df.empty: + raise RuntimeError("all tiles dropped by per-class thresholds") + + if include_folds is not None or exclude_folds is not None: + if "fold" not in df.columns: + raise RuntimeError( + "fold filter requested but 'fold' column not in metadata" + ) + if include_folds is not None: + df = df[df["fold"].isin(include_folds)] + if exclude_folds is not None: + df = df[~df["fold"].isin(exclude_folds)] + if df.empty: + raise RuntimeError("all tiles dropped by fold filter") + + return df[["slide_id", "x", "y", "label"]] + + @staticmethod + def _resolve_uri(path_or_uri: str | Path) -> str: s = str(path_or_uri) if s.startswith(("mlflow-artifacts:/", "runs:/")): - local = download_artifacts(artifact_uri=s) - else: - local = s - return pd.read_parquet(local) + return download_artifacts(artifact_uri=s) + return s diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 1ee25753..0d2cc556 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -1,6 +1,6 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Any +from typing import Any, cast import mlflow import numpy as np @@ -83,9 +83,11 @@ def __init__( def setup(self, stage: str) 
-> None: if stage == "fit": - labels = self.trainer.datamodule.train.labels + datamodule = cast(Any, self.trainer).datamodule + labels = datamodule.train.labels num_classes = len(self.class_names) counts = np.bincount(labels, minlength=num_classes).astype(float) + counts = np.maximum(counts, 1.0) weights = len(labels) / (num_classes * counts) self.criterion = nn.CrossEntropyLoss( weight=torch.tensor(weights, dtype=torch.float32) diff --git a/pyproject.toml b/pyproject.toml index bc54e933..0cff3865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,8 @@ dependencies = [ "tqdm>=4.66.0", "rationai-sdk", "ratiopath>=1.2.0", - "pyarrow>=19.0.0", - "datasets>=3.0.0", + "pyarrow>=19.0.1", + "datasets>=4.0.0", "scikit-learn>=1.8.0", "numpy>=2.3.5", "rationai-tiling>=1.1.1", @@ -32,8 +32,6 @@ dependencies = [ "timm>=1.0.0", "einops>=0.8.0", "matplotlib>=3.10.7", - "pyarrow>=19.0.1", - "datasets>=4.0.0", ] [dependency-groups] From bdce760107fd4fcba4ecdbadeff4d5f6003399ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:30:18 +0200 Subject: [PATCH 034/107] feat: add prints --- ml/data/datasets/embedding_tiles.py | 59 ++++++++++++++++++++++++++++- uv.lock | 2 - 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 6cf2c1df..bd86542b 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,6 +5,7 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. 
""" +import time from pathlib import Path import numpy as np @@ -42,6 +43,15 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: + tag = ( + f"include_folds={include_folds}" + if include_folds is not None + else f"exclude_folds={exclude_folds}" + if exclude_folds is not None + else "no_folds" + ) + print(f"[dataset] init: {tag}, metadata={metadata_uri}", flush=True) + t0 = time.time() meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -49,37 +59,73 @@ def __init__( include_folds, exclude_folds, ) + print( + f"[dataset] metadata filtered: {len(meta_df)} rows " + f"({time.time() - t0:.1f}s)", + flush=True, + ) + t = time.time() emb_dir = self._resolve_uri(embedding_uri) + print( + f"[dataset] embedding artifacts resolved in {time.time() - t:.1f}s", + flush=True, + ) + + t = time.time() emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + print( + f"[dataset] embedding parquet loaded: {emb_table.num_rows} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) + t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) + print( + f"[dataset] embedding column cast in {time.time() - t:.1f}s", flush=True + ) emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table + t = time.time() meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) del meta_df joined_keys = meta_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys, meta_table + print( + f"[dataset] arrow join: {joined_keys.num_rows} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) if joined_keys.num_rows == 0: raise RuntimeError("inner join with embeddings produced empty 
dataset") + t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings_arrow = emb_col.combine_chunks().take(indices) + emb_contig = emb_col.combine_chunks() + print( + f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True + ) + + t = time.time() + embeddings_arrow = emb_contig.take(indices) + print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + t = time.time() embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) self.embeddings = ( embeddings_arrow.values.to_numpy(zero_copy_only=False) @@ -94,6 +140,11 @@ def __init__( ) self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + print( + f"[dataset] numpy conversion done in {time.time() - t:.1f}s, " + f"total={time.time() - t0:.1f}s", + flush=True, + ) def __len__(self) -> int: return len(self.labels) @@ -113,8 +164,14 @@ def _filter_metadata( include_folds: list[int] | None, exclude_folds: list[int] | None, ) -> pd.DataFrame: + t = time.time() local = EmbeddingTilesDataset._resolve_uri(metadata_uri) df = pd.read_parquet(local) + print( + f"[dataset] metadata parquet loaded: {len(df)} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: diff --git a/uv.lock b/uv.lock index d884e875..c4ad7300 100644 --- a/uv.lock +++ b/uv.lock @@ -2325,7 +2325,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "datasets", specifier = ">=3.0.0" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, @@ -2336,7 +2335,6 @@ requires-dist = [ { name = "omegaconf", specifier = ">=2.3.0" }, { name = "openslide-python", specifier = ">=1.4.2" }, { name = "pandas", specifier = ">=2.0.0" }, - { name = "pyarrow", specifier = ">=19.0.0" }, { name = 
"pyarrow", specifier = ">=19.0.1" }, { name = "rationai-masks" }, { name = "rationai-mlkit", git = "https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git" }, From ac633d5717e1419a950c5b0fd5ff9b0c61f0265c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:40:09 +0200 Subject: [PATCH 035/107] fix: use chunks --- ml/data/datasets/embedding_tiles.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index bd86542b..73421fa1 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -116,15 +116,17 @@ def __init__( indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - emb_contig = emb_col.combine_chunks() + # take first on chunked array (avoids combine_chunks on full 1.1M rows) + embeddings_arrow = emb_col.take(indices) + print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + + t = time.time() + if isinstance(embeddings_arrow, pa.ChunkedArray): + embeddings_arrow = embeddings_arrow.combine_chunks() print( f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True ) - t = time.time() - embeddings_arrow = emb_contig.take(indices) - print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) - t = time.time() embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) self.embeddings = ( From 2793562c4c98fa21cf5d7477e978f8f8060a1882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:54:17 +0200 Subject: [PATCH 036/107] fix: use numpy chunks --- ml/data/datasets/embedding_tiles.py | 47 +++++++++++++++++++---------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 73421fa1..601dd08e 100644 
--- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -113,27 +113,42 @@ def __init__( raise RuntimeError("inner join with embeddings produced empty dataset") t = time.time() - indices = joined_keys.column("_emb_idx") - if isinstance(indices, pa.ChunkedArray): - indices = indices.combine_chunks() - # take first on chunked array (avoids combine_chunks on full 1.1M rows) - embeddings_arrow = emb_col.take(indices) - print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() - t = time.time() - if isinstance(embeddings_arrow, pa.ChunkedArray): - embeddings_arrow = embeddings_arrow.combine_chunks() + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // len(first_chunk) + + # sort indices for sequential per-chunk access; restore order afterwards + sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate( + [[0], np.cumsum([len(c) for c in emb_col.chunks])] + ) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & (sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col print( - f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True + f"[dataset] chunk-wise extraction done in {time.time() - t:.1f}s", + flush=True, ) t = time.time() - embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) - self.embeddings = ( - embeddings_arrow.values.to_numpy(zero_copy_only=False) - .astype(np.float32) 
- .reshape(len(embeddings_arrow), embedding_dim) - ) + self.embeddings = embeddings labels = joined_keys.column("label").to_pandas() unknown = set(labels.unique()) - set(class_indices.keys()) if unknown: From e81973eadb7f6a758f7fa7ee6eb1d16bb5cc1b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 21:17:32 +0200 Subject: [PATCH 037/107] fix: call end at the end of the main --- ml/__main__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml/__main__.py b/ml/__main__.py index 61ef4008..d531a08b 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -1,6 +1,7 @@ from random import randint import hydra +import mlflow from lightning import seed_everything from omegaconf import DictConfig, OmegaConf from rationai.mlkit import Trainer, autolog @@ -25,6 +26,7 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: model = hydra.utils.instantiate(config.model, _target_=MetaArch) trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + mlflow.end_run() if __name__ == "__main__": From 0071592f12c05301fa88ca3b22594b5ba4ca2938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:12:05 +0200 Subject: [PATCH 038/107] chore: remove prints --- ml/data/datasets/embedding_tiles.py | 54 ----------------------------- ml/meta_arch.py | 6 +--- preprocessing/embedding_dataset.py | 24 ------------- scripts/submit_train_linear.py | 6 ++-- 4 files changed, 4 insertions(+), 86 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 601dd08e..791a7342 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,7 +5,6 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. 
""" -import time from pathlib import Path import numpy as np @@ -43,15 +42,6 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: - tag = ( - f"include_folds={include_folds}" - if include_folds is not None - else f"exclude_folds={exclude_folds}" - if exclude_folds is not None - else "no_folds" - ) - print(f"[dataset] init: {tag}, metadata={metadata_uri}", flush=True) - t0 = time.time() meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -59,60 +49,32 @@ def __init__( include_folds, exclude_folds, ) - print( - f"[dataset] metadata filtered: {len(meta_df)} rows " - f"({time.time() - t0:.1f}s)", - flush=True, - ) - t = time.time() emb_dir = self._resolve_uri(embedding_uri) - print( - f"[dataset] embedding artifacts resolved in {time.time() - t:.1f}s", - flush=True, - ) - - t = time.time() emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) - print( - f"[dataset] embedding parquet loaded: {emb_table.num_rows} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) - t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) - print( - f"[dataset] embedding column cast in {time.time() - t:.1f}s", flush=True - ) emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table - t = time.time() meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) del meta_df joined_keys = meta_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys, meta_table - print( - f"[dataset] arrow join: {joined_keys.num_rows} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) if joined_keys.num_rows == 0: raise RuntimeError("inner join with embeddings produced empty 
dataset") - t = time.time() _idx_col = joined_keys.column("_emb_idx") if isinstance(_idx_col, pa.ChunkedArray): _idx_col = _idx_col.combine_chunks() @@ -142,12 +104,7 @@ def __init__( ) embeddings[sort_order[mask]] = chunk_np[local_idx] del emb_col - print( - f"[dataset] chunk-wise extraction done in {time.time() - t:.1f}s", - flush=True, - ) - t = time.time() self.embeddings = embeddings labels = joined_keys.column("label").to_pandas() unknown = set(labels.unique()) - set(class_indices.keys()) @@ -157,11 +114,6 @@ def __init__( ) self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() - print( - f"[dataset] numpy conversion done in {time.time() - t:.1f}s, " - f"total={time.time() - t0:.1f}s", - flush=True, - ) def __len__(self) -> int: return len(self.labels) @@ -181,14 +133,8 @@ def _filter_metadata( include_folds: list[int] | None, exclude_folds: list[int] | None, ) -> pd.DataFrame: - t = time.time() local = EmbeddingTilesDataset._resolve_uri(metadata_uri) df = pd.read_parquet(local) - print( - f"[dataset] metadata parquet loaded: {len(df)} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 0d2cc556..7f09cc13 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -28,10 +28,6 @@ class MetaArch(LightningModule): ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. """ - # TODO: support class_weights for CE loss when class distribution is heavily - # imbalanced. Inject either via config (list[float]) or compute from the - # train fold label distribution at setup(). 
- def __init__( self, backbone: nn.Module, @@ -83,7 +79,7 @@ def __init__( def setup(self, stage: str) -> None: if stage == "fit": - datamodule = cast(Any, self.trainer).datamodule + datamodule = cast("Any", self.trainer).datamodule labels = datamodule.train.labels num_classes = len(self.class_names) counts = np.bincount(labels, minlength=num_classes).astype(float) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 7669db9b..86091d4d 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -94,29 +93,19 @@ def join_embeddings( emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - t = time.time() joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys - print( - f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", - flush=True, - ) indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - t = time.time() emb_contig = emb_col.combine_chunks() del emb_col - print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) - - t = time.time() embeddings = emb_contig.take(indices) del emb_contig - print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows @@ -133,7 +122,6 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: - print(f"[{split_name}] downloading source tiles", flush=True) src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) @@ -182,11 +170,6 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) - print( - 
f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) " - f"-> {after_class_threshold} (class threshold), joining embeddings", - flush=True, - ) tiles_table = pa.Table.from_pandas(df, preserve_index=False) del df @@ -195,12 +178,6 @@ def process_split( tiles_table, embedding_run_id, split_name ) del tiles_table - if dropped_no_embedding != 0: - print( - f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " - "no matching embedding and were dropped on join.", - flush=True, - ) sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) merged_table = merged_table.take(sort_indices) @@ -214,7 +191,6 @@ def process_split( shutil.copy(slides_local, output_split_dir / "slides.parquet") log_label_distributions(split_name, merged_table) - print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True) return { "input_count": input_count, diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py index d19f6747..93cf6868 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier val_fold=0,1,2,3,4 --multirun", + "uv run python -m ml +experiment=... 
val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) From c0a7499de7fa602e72a2eb30d1840f519d97e350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:18:46 +0200 Subject: [PATCH 039/107] chore: remove debug prints, stale TODO, and unused preprocessing pipeline - Drop all print/logging/timing instrumentation from embedding_dataset.py and ml/data/datasets/embedding_tiles.py - Remove stale TODO comment in meta_arch.py (class_weights already implemented via setup() computing balanced weights from train fold label distribution) - Delete preprocessing/embedding_dataset.py and related configs/scripts (embedding dataset build pipeline not needed for this branch) - Add PR.md with title and description Co-Authored-By: Claude Sonnet 4.6 --- PR.md | 34 +++ .../preprocessing/embedding_dataset.yaml | 23 -- configs/preprocessing/embedding_dataset.yaml | 13 - preprocessing/embedding_dataset.py | 272 ------------------ scripts/submit_embedding_dataset.py | 18 -- 5 files changed, 34 insertions(+), 326 deletions(-) create mode 100644 PR.md delete mode 100644 configs/experiment/preprocessing/embedding_dataset.yaml delete mode 100644 configs/preprocessing/embedding_dataset.yaml delete mode 100644 preprocessing/embedding_dataset.py delete mode 100644 scripts/submit_embedding_dataset.py diff --git a/PR.md b/PR.md new file mode 100644 index 00000000..7c4eb762 --- /dev/null +++ b/PR.md @@ -0,0 +1,34 @@ +# feat: linear classifier training pipeline on precomputed embeddings + +## Summary + +Adds an end-to-end ML training pipeline for linear probing on precomputed tile +embeddings. Introduces the embedding dataset preprocessing step, a PyTorch +Lightning training module, and all supporting configs and submission scripts. + +## Changes + +### Preprocessing +- `preprocessing/_labels.py` — shared label/tissue-prop derivation logic. 
+ +### ML training +- `ml/meta_arch.py` — `MetaArch` Lightning module: backbone + decode head + + CrossEntropyLoss with balanced class weights computed from the train fold. + Logs per-class metrics, confusion matrices, and per-slide accuracy. +- `ml/data/datasets/embedding_tiles.py` — `EmbeddingTilesDataset`: loads the + embedding parquet, inner-joins with metadata, and serves `(embedding, label, + slide_id)` triples. Stays in Arrow for the join to avoid large-list → pandas + conversion overhead. +- `ml/data/data_module.py` — Lightning `DataModule` wrapping train/val/test splits. +- `ml/callbacks/parquet_prediction_writer.py` — writes model predictions to Parquet. +- `configs/experiment/ml/linear_classifier.yaml` — full experiment config. +- `configs/ml/` — model, data, and trainer sub-configs. +- `scripts/submit_train_linear.py` — MLflow submission script. + +## Test plan + +- [ ] Run `submit_train_linear.py`; verify training converges and MLflow logs + loss, macro F1, per-class metrics, and confusion matrix figures. +- [ ] Check class weights are logged under `class_weight/` in MLflow. +- [ ] Confirm `parquet_prediction_writer` produces a valid predictions Parquet + on the test split. diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml deleted file mode 100644 index 8004e2e4..00000000 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - _self_ - -tissue_prop_min: 0.2 -thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.5 - Epithelium: 0.2 - Muscle: 0.4 - Other: 0.5 - -metadata: - run_name: Embedding dataset ${dataset.name} - description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})." 
- hyperparams: - kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} - filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} diff --git a/configs/preprocessing/embedding_dataset.yaml b/configs/preprocessing/embedding_dataset.yaml deleted file mode 100644 index f4af56a6..00000000 --- a/configs/preprocessing/embedding_dataset.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package _global_ - -mlflow_artifact_path: embedding_dataset - -tissue_prop_min: ??? -thresholds: ??? - -metadata: - run_name: "Embedding dataset ${dataset.name}" - description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings." - hyperparams: - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py deleted file mode 100644 index 86091d4d..00000000 --- a/preprocessing/embedding_dataset.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Build an embedding training dataset by joining tile metadata with embeddings. - -Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles -metadata (test), applies tissue + per-class ROI thresholds before the join, and -emits a training-ready Parquet dataset (per-split ``slides.parquet`` + -``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``. 
-""" - -import shutil -import tempfile -from pathlib import Path - -import hydra -import mlflow -import mlflow.artifacts -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.dataset as pads -import pyarrow.parquet as pq -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog, with_cli_args -from rationai.mlkit.lightning.loggers import MLFlowLogger - -from preprocessing._labels import compute_label_and_tissue_prop - - -def apply_thresholds( - df: pd.DataFrame, - tissue_prop_min: float, - thresholds: dict[str, float], - roi_cols: list[str], -) -> tuple[pd.DataFrame, int]: - """Filter tiles by tissue + per-class thresholds and rewrite labels. - - Filters ``df`` by ``tissue_prop_min``, then keeps tiles where ANY class - meets its threshold; among passing classes, the highest-coverage one - becomes the label. - - Returns ``(filtered_df, after_tissue_count)`` so the caller can log both - intermediate counts. The returned df has its ``label`` column rewritten to - reflect the argmax-over-passers rule. - """ - df = df.loc[df["tissue_prop"] >= tissue_prop_min] - after_tissue = len(df) - if df.empty: - return df, after_tissue - - class_names = np.array([c.removeprefix("roi_coverage_") for c in roi_cols]) - thr = np.array([thresholds[c] for c in class_names], dtype=float) - roi = df[roi_cols].to_numpy() - passes = roi >= thr - keep = passes.any(axis=1) - - masked = np.where(passes, roi, -np.inf) - label_idx = masked.argmax(axis=1) - new_labels = class_names[label_idx] - - out = df.loc[pd.Series(keep, index=df.index)].copy() - out["label"] = new_labels[keep] - return out, after_tissue - - -def join_embeddings( - tiles_table: pa.Table, - embedding_run_id: str, - embedding_split: str, -) -> tuple[pa.Table, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y). - - Stays in Arrow throughout to avoid the very slow list -> pandas - conversion. 
Acero's join engine doesn't accept list columns in non-key - fields, so we join on keys plus a synthetic row index, then pull embeddings - via take(). The embedding column is cast per chunk to large_list to avoid - int32 offset overflow that bites take() when chunks are concatenated. - """ - emb_dir = mlflow.artifacts.download_artifacts( - run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" - ) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) - - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_col.type): - target_type = pa.large_list(emb_col.type.value_type) - emb_col = pa.chunked_array( - [c.cast(target_type) for c in emb_col.chunks], type=target_type - ) - - emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) - emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) - del emb_table, emb_idx - - joined_keys = tiles_table.join( - emb_keys, keys=["slide_id", "x", "y"], join_type="inner" - ) - del emb_keys - - indices = joined_keys.column("_emb_idx") - if isinstance(indices, pa.ChunkedArray): - indices = indices.combine_chunks() - - emb_contig = emb_col.combine_chunks() - del emb_col - embeddings = emb_contig.take(indices) - del emb_contig - - joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) - dropped_no_embedding = tiles_table.num_rows - joined.num_rows - return joined, dropped_no_embedding - - -def process_split( - split_name: str, - src_run_id: str, - src_artifact_path: str, - embedding_run_id: str, - tissue_prop_min: float, - thresholds: dict[str, float], - output_split_dir: Path, - derive: bool, -) -> dict[str, int]: - src_local = mlflow.artifacts.download_artifacts( - run_id=src_run_id, artifact_path=src_artifact_path - ) - df = pads.dataset(src_local, format="parquet").to_table().to_pandas() - input_count = len(df) - - roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] - if not roi_cols: 
- raise RuntimeError( - f"No roi_coverage_* columns in {src_artifact_path}. " - "Cannot apply class thresholds." - ) - - classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} - missing = classes_in_data - set(thresholds.keys()) - if missing: - raise ValueError( - f"thresholds is missing entries for roi_coverage_* classes present " - f"in data: {sorted(missing)}" - ) - - if derive: - lbl, tp = compute_label_and_tissue_prop(df, roi_cols) - df["label"] = lbl - df["tissue_prop"] = tp - else: - required = {"label", "tissue_prop"} - missing_required = required - set(df.columns) - if missing_required: - raise RuntimeError( - f"Source split '{split_name}' (derive=False) is missing required " - f"columns {sorted(missing_required)} in {src_artifact_path}. " - "Expected the kfold_split artifact, which writes label/tissue_prop/fold." - ) - - df, after_tissue_filter = apply_thresholds( - df, tissue_prop_min, thresholds, roi_cols - ) - after_class_threshold = len(df) - if after_class_threshold == 0: - raise RuntimeError( - f"All {input_count} tiles dropped by thresholds for split '{split_name}'." 
- ) - - drop_cols = [ - c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) - ] - df = df.drop(columns=drop_cols) - - tiles_table = pa.Table.from_pandas(df, preserve_index=False) - del df - - merged_table, dropped_no_embedding = join_embeddings( - tiles_table, embedding_run_id, split_name - ) - del tiles_table - - sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) - merged_table = merged_table.take(sort_indices) - - output_split_dir.mkdir(parents=True, exist_ok=True) - pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) - - slides_local = mlflow.artifacts.download_artifacts( - run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" - ) - shutil.copy(slides_local, output_split_dir / "slides.parquet") - - log_label_distributions(split_name, merged_table) - - return { - "input_count": input_count, - "after_tissue_filter": after_tissue_filter, - "after_class_threshold": after_class_threshold, - "after_join": merged_table.num_rows, - "dropped_no_embedding": dropped_no_embedding, - } - - -def log_label_distributions(split_name: str, table: pa.Table) -> None: - has_fold = "fold" in table.schema.names - cols = ["label", "fold"] if has_fold else ["label"] - df = table.select(cols).to_pandas() - - label_dist = ( - df["label"].value_counts().rename_axis("label").reset_index(name="count") - ) - mlflow.log_table( - data=label_dist, - artifact_file=f"fold_statistics/{split_name}_label_distribution.json", - ) - - if has_fold: - fold_dist = ( - df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() - ) - mlflow.log_table( - data=fold_dist, - artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json", - ) - - -@with_cli_args(["+preprocessing=embedding_dataset"]) -@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - artifacts = config.dataset.mlflow_artifacts 
- kfold_run_id = artifacts.kfold_run_id - filter_tiles_run_id = artifacts.filter_tiles_run_id - embedding_run_id = artifacts.embedding_run_id - - tissue_prop_min = float(config.tissue_prop_min) - if tissue_prop_min <= 0: - raise ValueError( - f"tissue_prop_min must be > 0 (got {tissue_prop_min}); " - "otherwise background tiles are not filtered out." - ) - raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True) - if not isinstance(raw_thresholds, dict): - raise TypeError("config.thresholds must be a mapping of class -> threshold") - thresholds = {str(k): float(v) for k, v in raw_thresholds.items()} - - splits = [ - ("train", kfold_run_id, "kfold_split/kfold_tiles.parquet", False), - ("test", filter_tiles_run_id, "filter_tiles/test_tiles.parquet", True), - ] - - with tempfile.TemporaryDirectory() as tmp_root: - tmp_root_path = Path(tmp_root) - for split_name, src_run_id, src_artifact_path, derive in splits: - stats = process_split( - split_name=split_name, - src_run_id=src_run_id, - src_artifact_path=src_artifact_path, - embedding_run_id=embedding_run_id, - tissue_prop_min=tissue_prop_min, - thresholds=thresholds, - output_split_dir=tmp_root_path / split_name, - derive=derive, - ) - for key, value in stats.items(): - mlflow.log_metric(f"{split_name}_{key}", value) - - mlflow.log_artifacts(str(tmp_root_path), config.mlflow_artifact_path) - - -if __name__ == "__main__": - main() diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py deleted file mode 100644 index 23977df5..00000000 --- a/scripts/submit_embedding_dataset.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-embedding-dataset", - username=..., - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m preprocessing.embedding_dataset +experiment=...", - 
], - storage=[storage.secure.PROJECTS], -) From fe918d152a75498982c4953e03b9ce81ce919de0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:25:56 +0200 Subject: [PATCH 040/107] chore: remove markdown file --- PR.md | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 PR.md diff --git a/PR.md b/PR.md deleted file mode 100644 index 7c4eb762..00000000 --- a/PR.md +++ /dev/null @@ -1,34 +0,0 @@ -# feat: linear classifier training pipeline on precomputed embeddings - -## Summary - -Adds an end-to-end ML training pipeline for linear probing on precomputed tile -embeddings. Introduces the embedding dataset preprocessing step, a PyTorch -Lightning training module, and all supporting configs and submission scripts. - -## Changes - -### Preprocessing -- `preprocessing/_labels.py` — shared label/tissue-prop derivation logic. - -### ML training -- `ml/meta_arch.py` — `MetaArch` Lightning module: backbone + decode head + - CrossEntropyLoss with balanced class weights computed from the train fold. - Logs per-class metrics, confusion matrices, and per-slide accuracy. -- `ml/data/datasets/embedding_tiles.py` — `EmbeddingTilesDataset`: loads the - embedding parquet, inner-joins with metadata, and serves `(embedding, label, - slide_id)` triples. Stays in Arrow for the join to avoid large-list → pandas - conversion overhead. -- `ml/data/data_module.py` — Lightning `DataModule` wrapping train/val/test splits. -- `ml/callbacks/parquet_prediction_writer.py` — writes model predictions to Parquet. -- `configs/experiment/ml/linear_classifier.yaml` — full experiment config. -- `configs/ml/` — model, data, and trainer sub-configs. -- `scripts/submit_train_linear.py` — MLflow submission script. - -## Test plan - -- [ ] Run `submit_train_linear.py`; verify training converges and MLflow logs - loss, macro F1, per-class metrics, and confusion matrix figures. 
-- [ ] Check class weights are logged under `class_weight/` in MLflow. -- [ ] Confirm `parquet_prediction_writer` produces a valid predictions Parquet - on the test split. From 6b7d1e8ec9bc2f45e5685ea03204416b018120d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 08:51:51 +0200 Subject: [PATCH 041/107] fix: edge cases --- ml/__main__.py | 3 +++ ml/callbacks/parquet_prediction_writer.py | 33 ++++++++--------------- ml/data/data_module.py | 5 +++- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/ml/__main__.py b/ml/__main__.py index d531a08b..318c37c4 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -25,6 +25,9 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) model = hydra.utils.instantiate(config.model, _target_=MetaArch) trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + allowed_modes = ["fit", "test", "validate", "predict"] + if config.mode not in allowed_modes: + raise ValueError(f"Invalid mode {config.mode!r}. 
Allowed: {allowed_modes}") getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) mlflow.end_run() diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index 1fe85580..38d86fd4 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -1,6 +1,5 @@ """Aggregate ``predict_step`` outputs and write them as a parquet artifact.""" -from collections.abc import Sequence from pathlib import Path from typing import Any @@ -22,19 +21,6 @@ class ParquetPredictionWriter(BasePredictionWriter): def __init__(self, output_filename: str = "predictions.parquet") -> None: super().__init__(write_interval="epoch") self.output_filename = output_filename - self._batches: list[dict[str, Any]] = [] - - def write_on_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - prediction: dict[str, Any], - batch_indices: Sequence[int] | None, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self._batches.append(prediction) def write_on_epoch_end( self, @@ -43,18 +29,23 @@ def write_on_epoch_end( predictions: Any, batch_indices: Any, ) -> None: - if not self._batches: + if trainer.global_rank != 0: return slide_ids: list[str] = [] targets: list[int] = [] preds: list[int] = [] probs: list[np.ndarray] = [] - for b in self._batches: - slide_ids.extend(b["slide_id"]) - targets.extend(b["target"].tolist()) - preds.extend(b["pred"].tolist()) - probs.append(b["probs"].numpy()) + for dataloader_preds in predictions: + for b in dataloader_preds: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) + + if not slide_ids: + return + prob_matrix = np.concatenate(probs, axis=0) class_names = getattr(pl_module, "class_names", None) @@ -74,5 +65,3 @@ def write_on_epoch_end( active = mlflow.active_run() if active is not None: mlflow.log_artifact(str(out_path), 
artifact_path="predictions") - - self._batches.clear() diff --git a/ml/data/data_module.py b/ml/data/data_module.py index c39b9503..7302be54 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -33,7 +33,10 @@ def setup(self, stage: str) -> None: case "test": self.test = instantiate(self.datasets["test"]) case "predict": - self.predict = instantiate(self.datasets["predict"]) + dataset_cfg = self.datasets.get("predict") or self.datasets.get("test") + if dataset_cfg is None: + raise KeyError("Neither 'predict' nor 'test' dataset configured") + self.predict = instantiate(dataset_cfg) def train_dataloader(self) -> Iterable[Input]: return DataLoader( From 4ff988ef071f8ccb9fbacd4be701209db3d10c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 09:55:59 +0200 Subject: [PATCH 042/107] feat: normalize the confusion matrix rows per class recall --- ml/meta_arch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 7f09cc13..9ee227a8 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -204,8 +204,11 @@ def _log_per_slide_accuracy(self) -> None: def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str ) -> Figure: + row_sums = matrix.sum(axis=1, keepdims=True) + normalized = np.divide(matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float)) + fig, ax = plt.subplots(figsize=(6, 5)) - im = ax.imshow(matrix, cmap="Blues") + im = ax.imshow(normalized, cmap="Blues", vmin=0, vmax=1) ax.set_title(title) ax.set_xlabel("Predicted") ax.set_ylabel("True") @@ -222,7 +225,7 @@ def _confmat_figure( str(matrix[i, j]), ha="center", va="center", - color="white" if matrix[i, j] > matrix.max() / 2 else "black", + color="white" if normalized[i, j] > 0.5 else "black", fontsize=8, ) fig.colorbar(im, ax=ax) From 32375b27f5753cc17a018492f74b617c3578614f Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 09:56:46 +0200 Subject: [PATCH 043/107] fix: format --- ml/meta_arch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 9ee227a8..04de7571 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -205,7 +205,9 @@ def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str ) -> Figure: row_sums = matrix.sum(axis=1, keepdims=True) - normalized = np.divide(matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float)) + normalized = np.divide( + matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float) + ) fig, ax = plt.subplots(figsize=(6, 5)) im = ax.imshow(normalized, cmap="Blues", vmin=0, vmax=1) From af9538a09acde88a87627400e4511b743e658122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 15:13:42 +0200 Subject: [PATCH 044/107] feat: use stratified k fold run --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 172e48fb..73926866 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - kfold_run_id: "850c81506684450b9af92296acfd045a" + kfold_run_id: "814611e8987d4d569b255b7a4749bc90" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" From bc0819a1434324a96a4c4b21328ba47823f70746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 18:22:58 +0200 Subject: [PATCH 045/107] fix: remove criterion --- ml/meta_arch.py | 8 ++++---- 1 file 
changed, 4 insertions(+), 4 deletions(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 04de7571..07d14311 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -22,27 +22,27 @@ class MetaArch(LightningModule): - """Top-level classification architecture: backbone + decode_head + criterion. + """Top-level classification architecture: backbone + decode_head. For linear probing on precomputed embeddings, ``backbone`` is typically ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. + Criterion is class-weighted CrossEntropyLoss, computed from training labels in setup(). """ def __init__( self, backbone: nn.Module, decode_head: nn.Module, - criterion: nn.Module, class_indices: dict[str, int], learning_rate: float = 1e-3, weight_decay: float = 0.0, ) -> None: super().__init__() - self.save_hyperparameters(ignore=["backbone", "decode_head", "criterion"]) + self.save_hyperparameters(ignore=["backbone", "decode_head"]) self.backbone = backbone self.decode_head = decode_head - self.criterion = criterion + self.criterion: nn.Module self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) From b8e85e0af5c6d7fb576b815c74befcf1687b71a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 18:27:33 +0200 Subject: [PATCH 046/107] fix: remove criterion from configs --- configs/ml/model/linear_classifier.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index dfff43ca..86c0ed25 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -9,9 +9,6 @@ model: in_features: 2560 out_features: ${len:${class_indices}} - criterion: - _target_: torch.nn.CrossEntropyLoss - class_indices: ${class_indices} learning_rate: 1.0e-3 From c387189f2d587ba7e24169022fa080b32ea151fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= 
<550433@mail.muni.cz> Date: Wed, 13 May 2026 21:03:57 +0200 Subject: [PATCH 047/107] feat: implement test pipeline --- ml/__main__.py | 13 ++++++++++++- ml/meta_arch.py | 4 +++- scripts/submit_test_linear.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 scripts/submit_test_linear.py diff --git a/ml/__main__.py b/ml/__main__.py index 318c37c4..60effaa5 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -3,6 +3,7 @@ import hydra import mlflow from lightning import seed_everything +from mlflow.artifacts import download_artifacts from omegaconf import DictConfig, OmegaConf from rationai.mlkit import Trainer, autolog from rationai.mlkit.lightning.loggers import MLFlowLogger @@ -28,9 +29,19 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: allowed_modes = ["fit", "test", "validate", "predict"] if config.mode not in allowed_modes: raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}") - getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + getattr(trainer, config.mode)( + model, datamodule=data, ckpt_path=_resolve_checkpoint(config.checkpoint) + ) mlflow.end_run() +def _resolve_checkpoint(checkpoint: str | None) -> str | None: + if checkpoint is None: + return None + if checkpoint.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=checkpoint) + return checkpoint + + if __name__ == "__main__": main() diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 07d14311..18a74803 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -90,6 +90,8 @@ def setup(self, stage: str) -> None: ) for cls, w in zip(self.class_names, weights.tolist(), strict=True): mlflow.log_metric(f"class_weight/{cls}", w) + else: + self.criterion = nn.CrossEntropyLoss() def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) @@ -197,7 +199,7 @@ def _log_per_slide_accuracy(self) -> None: ] mlflow.log_table( data=pd.DataFrame(rows), - 
artifact_file="per_slide/test_tile_accuracy.json", + artifact_file="per_slide/test_tile_accuracy.parquet", ) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py new file mode 100644 index 00000000..58d5af29 --- /dev/null +++ b/scripts/submit_test_linear.py @@ -0,0 +1,31 @@ +from kube_jobs import storage, submit_job + + +fold_checkpoints = { + 0: "runs://checkpoints/.ckpt", + 1: "runs://checkpoints/.ckpt", + 2: "runs://checkpoints/.ckpt", + 3: "runs://checkpoints/.ckpt", + 4: "runs://checkpoints/.ckpt", +} + + +submit_job( + job_name="tissue-classification-test-linear", + username=..., + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + *[ + "uv run python -m ml +experiment=ml/linear_classifier " + f"mode=test val_fold={fold} checkpoint={checkpoint}" + for fold, checkpoint in fold_checkpoints.items() + ], + ], + storage=[storage.secure.PROJECTS], +) From 12165042abf62e24ab884f713f86118341b37bf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 21:17:59 +0200 Subject: [PATCH 048/107] fix: Hydra unreached --- scripts/submit_test_linear.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index 58d5af29..063952e2 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -2,28 +2,28 @@ fold_checkpoints = { - 0: "runs://checkpoints/.ckpt", - 1: "runs://checkpoints/.ckpt", - 2: "runs://checkpoints/.ckpt", - 3: "runs://checkpoints/.ckpt", - 4: "runs://checkpoints/.ckpt", + 0: "mlflow-artifacts:/104/26a6f9c741d543c9a09b54048be527a1/artifacts/checkpoints/epoch=3-val_loss=0.1973/checkpoint.ckpt", + 1: "mlflow-artifacts:/104/cc2be862324a446baffd9a8d90be604d/artifacts/checkpoints/epoch=1-val_loss=0.1218/checkpoint.ckpt", + 2: 
"mlflow-artifacts:/104/8454857b11984419bb7eae02a520ec71/artifacts/checkpoints/epoch=0-val_loss=0.2980/checkpoint.ckpt", + 3: "mlflow-artifacts:/104/bfa52277ea2744b9ab523c56a905dcda/artifacts/checkpoints/epoch=0-val_loss=1.0547/checkpoint.ckpt", + 4: "mlflow-artifacts:/104/358cd6ee286b4d67b7c12cf9bce0c3b4/artifacts/checkpoints/epoch=0-val_loss=0.1462/checkpoint.ckpt", } submit_job( job_name="tissue-classification-test-linear", - username=..., + username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", *[ "uv run python -m ml +experiment=ml/linear_classifier " - f"mode=test val_fold={fold} checkpoint={checkpoint}" + f'mode=test val_fold={fold} checkpoint=\\"{checkpoint}\\"' for fold, checkpoint in fold_checkpoints.items() ], ], From 7ec86efd96af995a314c3d740d60a2f4cb1013f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 21:26:46 +0200 Subject: [PATCH 049/107] fix: set weights only to false --- ml/__main__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ml/__main__.py b/ml/__main__.py index 60effaa5..7cf951cd 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -30,7 +30,10 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: if config.mode not in allowed_modes: raise ValueError(f"Invalid mode {config.mode!r}. 
Allowed: {allowed_modes}") getattr(trainer, config.mode)( - model, datamodule=data, ckpt_path=_resolve_checkpoint(config.checkpoint) + model, + datamodule=data, + ckpt_path=_resolve_checkpoint(config.checkpoint), + weights_only=False, ) mlflow.end_run() From c9b566ea08b2167953bf1c5b18ec381e37d13090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 21:36:37 +0200 Subject: [PATCH 050/107] fix: criterion weight --- ml/meta_arch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 18a74803..11ba22a3 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -91,7 +91,9 @@ def setup(self, stage: str) -> None: for cls, w in zip(self.class_names, weights.tolist(), strict=True): mlflow.log_metric(f"class_weight/{cls}", w) else: - self.criterion = nn.CrossEntropyLoss() + self.criterion = nn.CrossEntropyLoss( + weight=torch.ones(len(self.class_names), dtype=torch.float32) + ) def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) From 3cc670de01f98c9774f56b2c8bcf944ac76fee7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 22:26:12 +0200 Subject: [PATCH 051/107] feat: add option to use different kfold strategies --- configs/data/dataset.yaml | 5 +++-- .../ml/linear_classifier_stratified_group_kfold.yaml | 8 ++++++++ .../experiment/ml/linear_classifier_stratified_kfold.yaml | 8 ++++++++ configs/{experiment => }/ml/linear_classifier.yaml | 8 +++++--- 4 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/linear_classifier_stratified_kfold.yaml rename configs/{experiment => }/ml/linear_classifier.yaml (81%) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 73926866..0ab4e0d5 100644 --- a/configs/data/dataset.yaml +++ 
b/configs/data/dataset.yaml @@ -14,7 +14,8 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" @@ -58,4 +59,4 @@ dataset: "44 Klarzelliges Nierenzellkarzinom_1": "kidney" "50 Muzinöses Zystadenom_1": "breast" "85 Mammakarzinom NST": "breast" - "28 Zöliakie": "small intestine" \ No newline at end of file + "28 Zöliakie": "small intestine" diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml new file mode 100644 index 00000000..471f5a36 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml new file mode 100644 index 00000000..c01fbbf9 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml similarity index 81% rename from configs/experiment/ml/linear_classifier.yaml rename to configs/ml/linear_classifier.yaml index 2e9a9525..5606f0a5 100644 --- 
a/configs/experiment/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -11,7 +11,8 @@ defaults: mode: fit embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} train_embedding_uri: runs:/${embedding_run_id}/train/tiles @@ -34,10 +35,11 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} fold=${val_fold} - description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), kfold metadata ${kfold_run_id}." + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} + description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} kfold_run_id: ${kfold_run_id} filter_tiles_run_id: ${filter_tiles_run_id} val_fold: ${val_fold} From ad0a4e7e714b8a17f6b61dddb024fb0bc7dd3e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 23:59:26 +0200 Subject: [PATCH 052/107] feat: add training without validation --- ml/callbacks/parquet_prediction_writer.py | 8 +- ml/data/datasets/embedding_tiles.py | 6 +- ml/meta_arch.py | 114 +++++++++++++++++++++- ml/typing.py | 4 +- 4 files changed, 124 insertions(+), 8 deletions(-) diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index 38d86fd4..42e8dcc3 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -33,12 +33,16 @@ def write_on_epoch_end( return slide_ids: list[str] = [] + xs: list[int] = [] + ys: list[int] = [] targets: list[int] = [] preds: list[int] = [] probs: 
list[np.ndarray] = [] for dataloader_preds in predictions: for b in dataloader_preds: slide_ids.extend(b["slide_id"]) + xs.extend(b["x"].tolist()) + ys.extend(b["y"].tolist()) targets.extend(b["target"].tolist()) preds.extend(b["pred"].tolist()) probs.append(b["probs"].numpy()) @@ -55,7 +59,9 @@ def write_on_epoch_end( else [f"prob_{i}" for i in range(prob_matrix.shape[1])] ) - df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds}) + df = pd.DataFrame( + {"slide_id": slide_ids, "x": xs, "y": ys, "target": targets, "pred": preds} + ) df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1) out_path = Path(trainer.default_root_dir) / self.output_filename diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 791a7342..a2474a5a 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -2,7 +2,7 @@ Joins precomputed tile embeddings with tile metadata (k-fold parquet for train, filter_tiles parquet for test) and applies tissue + per-class thresholds at -load time to produce ``(embedding, class_index, slide_id)`` triples. +load time to produce ``(embedding, class_index, slide_id, x, y)`` samples. 
""" from pathlib import Path @@ -114,6 +114,8 @@ def __init__( ) self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) + self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) def __len__(self) -> int: return len(self.labels) @@ -123,6 +125,8 @@ def __getitem__(self, idx: int) -> Sample: torch.from_numpy(self.embeddings[idx]), int(self.labels[idx]), str(self.slide_ids[idx]), + int(self.xs[idx]), + int(self.ys[idx]), ) @staticmethod diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 11ba22a3..a016dc47 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -1,5 +1,7 @@ from collections import defaultdict from collections.abc import Iterable +from pathlib import Path +from re import sub from typing import Any, cast import mlflow @@ -21,6 +23,9 @@ from ml.typing import Input, Outputs +MAX_TEST_PREDICTION_MAPS = 20 + + class MetaArch(LightningModule): """Top-level classification architecture: backbone + decode_head. 
@@ -76,6 +81,7 @@ def __init__( self._test_slide_correct: dict[str, int] = defaultdict(int) self._test_slide_total: dict[str, int] = defaultdict(int) + self._test_tile_rows: list[dict[str, Any]] = [] def setup(self, stage: str) -> None: if stage == "fit": @@ -100,14 +106,14 @@ def forward(self, x: Tensor) -> Outputs: return self.decode_head(features) def training_step(self, batch: Input, batch_idx: int) -> Tensor: - inputs, targets, _ = batch + inputs, targets, *_ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) return loss def validation_step(self, batch: Input, batch_idx: int) -> None: - inputs, targets, _ = batch + inputs, targets, *_ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) self.log("validation/loss", loss, on_epoch=True, prog_bar=True) @@ -121,7 +127,7 @@ def on_validation_epoch_end(self) -> None: self._log_confmat(self.val_confmat, "validation") def test_step(self, batch: Input, batch_idx: int) -> None: - inputs, targets, slide_ids = batch + inputs, targets, slide_ids, xs, ys = batch outputs = self(inputs) self.test_metrics.update(outputs, targets) self.test_per_class.update(outputs, targets) @@ -133,23 +139,44 @@ def test_step(self, batch: Input, batch_idx: int) -> None: for slide_id, ok in zip(slide_ids, correct, strict=True): self._test_slide_correct[slide_id] += int(ok) self._test_slide_total[slide_id] += 1 + self._test_tile_rows.extend( + { + "slide_id": slide_id, + "x": int(x), + "y": int(y), + "target": int(target), + "pred": int(pred), + } + for slide_id, x, y, target, pred in zip( + slide_ids, + xs.cpu().tolist(), + ys.cpu().tolist(), + targets.cpu().tolist(), + preds.cpu().tolist(), + strict=True, + ) + ) def on_test_epoch_end(self) -> None: self._log_per_class(self.test_per_class, "test") self._log_confmat(self.test_confmat, "test") self._log_per_slide_accuracy() + self._log_prediction_maps() self._test_slide_correct.clear() 
self._test_slide_total.clear() + self._test_tile_rows.clear() def predict_step( self, batch: Input, batch_idx: int, dataloader_idx: int = 0 ) -> dict[str, Any]: - inputs, targets, slide_ids = batch + inputs, targets, slide_ids, xs, ys = batch outputs = self(inputs) probs = outputs.softmax(dim=1) preds = outputs.argmax(dim=1) return { "slide_id": list(slide_ids), + "x": xs.cpu(), + "y": ys.cpu(), "target": targets.cpu(), "pred": preds.cpu(), "probs": probs.cpu(), @@ -204,6 +231,27 @@ def _log_per_slide_accuracy(self) -> None: artifact_file="per_slide/test_tile_accuracy.parquet", ) + def _log_prediction_maps(self) -> None: + if not self._test_tile_rows: + return + df = pd.DataFrame(self._test_tile_rows) + df["_correct"] = df["pred"] == df["target"] + slide_order = ( + df.groupby("slide_id", sort=False)["_correct"] + .mean() + .sort_values() + .head(MAX_TEST_PREDICTION_MAPS) + .index + ) + for slide_id in slide_order: + slide_df = df[df["slide_id"] == slide_id] + fig = _prediction_map_figure(slide_df, self.class_names) + artifact_file = f"prediction_maps/{_safe_filename(slide_id)}.png" + try: + mlflow.log_figure(fig, artifact_file=artifact_file) + finally: + plt.close(fig) + def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str @@ -237,3 +285,61 @@ def _confmat_figure( fig.colorbar(im, ax=ax) fig.tight_layout() return fig + + +def _prediction_map_figure(df: pd.DataFrame, class_names: list[str]) -> Figure: + fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) + palette = plt.get_cmap("tab10", len(class_names)) + + axes[0].scatter( + df["x"], + df["y"], + c=df["pred"], + cmap=palette, + vmin=-0.5, + vmax=len(class_names) - 0.5, + marker="s", + s=4, + linewidths=0, + ) + axes[0].set_title("Prediction") + + correct = df["pred"].to_numpy() == df["target"].to_numpy() + axes[1].scatter( + df["x"], + df["y"], + c=np.where(correct, 0, 1), + cmap=plt.get_cmap("Set1", 2), + vmin=-0.5, + vmax=1.5, + marker="s", + s=4, + linewidths=0, 
+ ) + axes[1].set_title(f"Errors ({int((~correct).sum())}/{len(df)})") + + handles = [ + plt.Line2D( + [0], + [0], + marker="s", + color="w", + label=cls, + markerfacecolor=palette(i), + markersize=6, + ) + for i, cls in enumerate(class_names) + ] + axes[0].legend(handles=handles, loc="upper left", bbox_to_anchor=(1.02, 1.0)) + + for ax in axes: + ax.set_aspect("equal", adjustable="box") + ax.invert_yaxis() + ax.set_xlabel("x") + ax.set_ylabel("y") + + return fig + + +def _safe_filename(value: str) -> str: + return sub(r"[^A-Za-z0-9_.-]+", "_", Path(value).stem or value) diff --git a/ml/typing.py b/ml/typing.py index 7060f5e4..5ad219c8 100644 --- a/ml/typing.py +++ b/ml/typing.py @@ -1,6 +1,6 @@ import torch -type Sample = tuple[torch.Tensor, int, str] -type Input = tuple[torch.Tensor, torch.Tensor, list[str]] +type Sample = tuple[torch.Tensor, int, str, int, int] +type Input = tuple[torch.Tensor, torch.Tensor, list[str], torch.Tensor, torch.Tensor] type Outputs = torch.Tensor From 811e21c69b4e2128ec98294bc07d4cfc49eaf559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 00:00:11 +0200 Subject: [PATCH 053/107] feat: implement final test run --- .../ml/linear_classifier_final.yaml | 46 +++++++++++++++++++ configs/ml/data/embedding_final.yaml | 21 +++++++++ configs/ml/trainer/final.yaml | 23 ++++++++++ ml/data/data_module.py | 10 +++- scripts/submit_test_linear_final.py | 23 ++++++++++ scripts/submit_train_linear_final.py | 18 ++++++++ 6 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_final.yaml create mode 100644 configs/ml/data/embedding_final.yaml create mode 100644 configs/ml/trainer/final.yaml create mode 100644 scripts/submit_test_linear_final.py create mode 100644 scripts/submit_train_linear_final.py diff --git a/configs/experiment/ml/linear_classifier_final.yaml b/configs/experiment/ml/linear_classifier_final.yaml new file mode 100644 
index 00000000..05daafd0 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_final.yaml @@ -0,0 +1,46 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/trainer: final + - /ml/data: embedding_final + - /ml/model: linear_classifier + - _self_ + +mode: fit + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +mlflow_artifact_path: linear_classifier_final + +metadata: + run_name: Final Linear Classifier ${dataset.name} + description: "Final linear probe over frozen Virchow2 embeddings trained on all training folds for ${trainer.max_epochs} epochs." 
+ hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + learning_rate: ${model.learning_rate} + weight_decay: ${model.weight_decay} + batch_size: ${data.batch_size} + max_epochs: ${trainer.max_epochs} diff --git a/configs/ml/data/embedding_final.yaml b/configs/ml/data/embedding_final.yaml new file mode 100644 index 00000000..a1f6ca80 --- /dev/null +++ b/configs/ml/data/embedding_final.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +data: + batch_size: 1024 + num_workers: 4 + + train: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} + + test: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} diff --git a/configs/ml/trainer/final.yaml b/configs/ml/trainer/final.yaml new file mode 100644 index 00000000..33db2b99 --- /dev/null +++ b/configs/ml/trainer/final.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +trainer: + max_epochs: 6 + accelerator: auto + devices: auto + precision: 32 + log_every_n_steps: 50 + deterministic: false + + callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + save_last: true + save_top_k: 0 + filename: "final-epoch={epoch}" + auto_insert_metric_name: false + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + prediction_writer: + _target_: ml.callbacks.ParquetPredictionWriter + output_filename: predictions.parquet diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 7302be54..b1275f9f 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -27,7 
+27,11 @@ def setup(self, stage: str) -> None: match stage: case "fit": self.train = instantiate(self.datasets["train"]) - self.val = instantiate(self.datasets["val"]) + self.val = ( + instantiate(self.datasets["val"]) + if self.datasets.get("val") is not None + else None + ) case "validate": self.val = instantiate(self.datasets["val"]) case "test": @@ -48,7 +52,9 @@ def train_dataloader(self) -> Iterable[Input]: persistent_workers=self.num_workers > 0, ) - def val_dataloader(self) -> Iterable[Input]: + def val_dataloader(self) -> Iterable[Input] | None: + if self.val is None: + return None return DataLoader( self.val, batch_size=self.batch_size, diff --git a/scripts/submit_test_linear_final.py b/scripts/submit_test_linear_final.py new file mode 100644 index 00000000..8176035a --- /dev/null +++ b/scripts/submit_test_linear_final.py @@ -0,0 +1,23 @@ +from kube_jobs import storage, submit_job + + +checkpoint = ( + "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" +) + + +submit_job( + job_name="tissue-classification-test-linear-final", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + f'uv run python -m ml +experiment=ml/linear_classifier_final mode=test checkpoint=\\"{checkpoint}\\"', + ], + storage=[storage.secure.PROJECTS], +) diff --git a/scripts/submit_train_linear_final.py b/scripts/submit_train_linear_final.py new file mode 100644 index 00000000..def02bf1 --- /dev/null +++ b/scripts/submit_train_linear_final.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear-final", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run 
python -m ml +experiment=ml/linear_classifier_final", + ], + storage=[storage.secure.PROJECTS], +) From 27ceea344ca0d9ce230864ef1d83301e533273ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 00:05:41 +0200 Subject: [PATCH 054/107] fix: lower LR and patience --- configs/ml/model/linear_classifier.yaml | 2 +- configs/ml/trainer/default.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index 86c0ed25..fc9bfedd 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,5 +11,5 @@ model: class_indices: ${class_indices} - learning_rate: 1.0e-3 + learning_rate: 1.0e-4 weight_decay: 0.0 diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index cf5766dc..a465a5cf 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -13,7 +13,7 @@ trainer: _target_: lightning.pytorch.callbacks.EarlyStopping monitor: validation/loss mode: min - patience: 5 + patience: 2 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: validation/loss From efde82ab3d76965011dc1a2335b8a6e84b9ca132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 13:26:57 +0200 Subject: [PATCH 055/107] fix: use f1 macro as a monitor --- configs/ml/trainer/default.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index a465a5cf..63e3c150 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,15 +11,15 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/loss - mode: min + monitor: validation/f1_macro + mode: max patience: 2 model_checkpoint: _target_: 
lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/loss - mode: min + monitor: validation/f1_macro + mode: max save_top_k: 1 - filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + filename: "epoch={epoch}-val_f1={validation/f1_macro:.4f}" auto_insert_metric_name: false lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From c8102de253e7c1c19baf530881f7415b8a80b1dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 15:22:12 +0200 Subject: [PATCH 056/107] fix: revert back to validation loss --- configs/ml/trainer/default.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index 63e3c150..a465a5cf 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,15 +11,15 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/f1_macro - mode: max + monitor: validation/loss + mode: min patience: 2 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/f1_macro - mode: max + monitor: validation/loss + mode: min save_top_k: 1 - filename: "epoch={epoch}-val_f1={validation/f1_macro:.4f}" + filename: "epoch={epoch}-val_loss={validation/loss:.4f}" auto_insert_metric_name: false lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From c5bab90c0dadc9ddd74fdf55101fbefeaae966ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 17:25:36 +0200 Subject: [PATCH 057/107] fix: add weight decay 1e-3 to linear classifier Train loss ~0.02 vs val loss ~0.32 indicated severe overfit on the linear probe. AdamW weight_decay was 0; bump to 1e-3 to regularize the head. 
--- configs/ml/model/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index fc9bfedd..d5558048 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -12,4 +12,4 @@ model: class_indices: ${class_indices} learning_rate: 1.0e-4 - weight_decay: 0.0 + weight_decay: 1.0e-3 From 475b67c7a05af8bd5ab071182cf9cafb81c43498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 19:05:00 +0200 Subject: [PATCH 058/107] Revert "fix: add weight decay 1e-3 to linear classifier" This reverts commit c5bab90c0dadc9ddd74fdf55101fbefeaae966ee. --- configs/ml/model/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index d5558048..fc9bfedd 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -12,4 +12,4 @@ model: class_indices: ${class_indices} learning_rate: 1.0e-4 - weight_decay: 1.0e-3 + weight_decay: 0.0 From 43663a9326e5d54dbf72335ee1d0d423b53bcda3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 20:28:59 +0200 Subject: [PATCH 059/107] feat: add logistic regression --- ...tic_regression_stratified_group_kfold.yaml | 8 + ..._logistic_regression_stratified_kfold.yaml | 8 + configs/ml/lbfgs_logistic_regression.yaml | 64 +++++++ ml/__main__.py | 7 + ml/sklearn_linear.py | 176 ++++++++++++++++++ 5 files changed, 263 insertions(+) create mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml create mode 100644 configs/ml/lbfgs_logistic_regression.yaml create mode 100644 ml/sklearn_linear.py diff --git 
a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml new file mode 100644 index 00000000..49c954e0 --- /dev/null +++ b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/lbfgs_logistic_regression + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml new file mode 100644 index 00000000..3d15556d --- /dev/null +++ b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/lbfgs_logistic_regression + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml new file mode 100644 index 00000000..c6ef9997 --- /dev/null +++ b/configs/ml/lbfgs_logistic_regression.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/data: embedding + - _self_ + +mode: fit +runner: sklearn_linear + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +val_fold: 0 + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + 
Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +model: + solver: lbfgs + penalty: l2 + C: 1.0 + class_weight: balanced + max_iter: 1000 + tol: 1.0e-4 + n_jobs: null + verbose: 0 + standardize: true + +mlflow_artifact_path: lbfgs_logistic_regression + +metadata: + run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} + description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." + hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + solver: ${model.solver} + penalty: ${model.penalty} + C: ${model.C} + class_weight: ${model.class_weight} + max_iter: ${model.max_iter} + tol: ${model.tol} + standardize: ${model.standardize} diff --git a/ml/__main__.py b/ml/__main__.py index 318c37c4..620c882f 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -20,6 +20,13 @@ @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: + if config.get("runner") == "sklearn_linear": + from ml.sklearn_linear import run as run_sklearn_linear + + run_sklearn_linear(config) + mlflow.end_run() + return + seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py new file mode 100644 index 00000000..f8fb5368 --- /dev/null +++ b/ml/sklearn_linear.py @@ -0,0 +1,176 @@ +from pathlib import Path +from random import randint +from typing import Any + +import hydra +import joblib +import mlflow +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog 
+from rationai.mlkit.lightning.loggers import MLFlowLogger +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import confusion_matrix, f1_score, recall_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + +from ml.data import DataModule +from ml.meta_arch import _confmat_figure + + +if not OmegaConf.has_resolver("random_seed"): + OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True + ) +if not OmegaConf.has_resolver("len"): + OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + run(config) + mlflow.end_run() + + +def run(config: DictConfig) -> None: + if config.mode != "fit": + raise ValueError("sklearn_linear currently supports only mode='fit'") + + np.random.seed(config.seed) + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + data.setup("fit") + + x_train = data.train.embeddings + y_train = data.train.labels + x_val = data.val.embeddings + y_val = data.val.labels + + class_names = [ + n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) + ] + model = _build_model(config) + model.fit(x_train, y_train) + + _log_split_metrics( + model, x_val, y_val, data.val.slide_ids, class_names, "validation" + ) + _log_model(model) + + mlflow.log_params( + { + "model_type": "sklearn_logistic_regression", + "solver": config.model.solver, + "penalty": config.model.penalty, + "C": config.model.C, + "max_iter": config.model.max_iter, + "tol": config.model.tol, + "class_weight": config.model.class_weight, + "standardize": config.model.standardize, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + ) + + +def _build_model(config: DictConfig) -> Pipeline: + steps: list[tuple[str, Any]] = [] + if config.model.standardize: + steps.append(("scaler", 
StandardScaler())) + steps.append( + ( + "classifier", + LogisticRegression( + C=config.model.C, + class_weight=config.model.class_weight, + max_iter=config.model.max_iter, + n_jobs=config.model.n_jobs, + penalty=config.model.penalty, + random_state=config.seed, + solver=config.model.solver, + tol=config.model.tol, + verbose=config.model.verbose, + ), + ) + ) + return Pipeline(steps) + + +def _log_split_metrics( + model: Pipeline, + inputs: np.ndarray, + targets: np.ndarray, + slide_ids: np.ndarray, + class_names: list[str], + split: str, +) -> None: + labels = np.arange(len(class_names)) + preds = model.predict(inputs) + probs = _predict_proba_for_all_classes(model, inputs, labels) + + mlflow.log_metric( + f"{split}/acc_macro", + recall_score(targets, preds, labels=labels, average="macro", zero_division=0), + ) + mlflow.log_metric( + f"{split}/f1_macro", + f1_score(targets, preds, average="macro", zero_division=0), + ) + + per_class_acc = recall_score( + targets, preds, labels=labels, average=None, zero_division=0 + ) + per_class_f1 = f1_score( + targets, preds, labels=labels, average=None, zero_division=0 + ) + for cls_name, acc, f1 in zip( + class_names, per_class_acc.tolist(), per_class_f1.tolist(), strict=True + ): + mlflow.log_metric(f"{split}/acc_per_class/{cls_name}", acc) + mlflow.log_metric(f"{split}/f1_per_class/{cls_name}", f1) + + matrix = confusion_matrix(targets, preds, labels=labels) + fig = _confmat_figure(matrix, class_names, title=f"{split} confmat") + try: + mlflow.log_figure(fig, artifact_file=f"confusion_matrix/{split}.png") + finally: + plt.close(fig) + + prob_columns = [f"prob_{c}" for c in class_names] + predictions = pd.DataFrame( + { + "slide_id": slide_ids, + "target": targets, + "pred": preds, + } + ) + predictions = pd.concat( + [predictions, pd.DataFrame(probs, columns=prob_columns)], axis=1 + ) + out_path = Path(f"{split}_predictions.parquet") + predictions.to_parquet(out_path, index=False) + mlflow.log_artifact(str(out_path), 
artifact_path="predictions") + + +def _predict_proba_for_all_classes( + model: Pipeline, inputs: np.ndarray, labels: np.ndarray +) -> np.ndarray: + raw_probs = model.predict_proba(inputs) + probs = np.zeros((len(inputs), len(labels)), dtype=raw_probs.dtype) + for source_idx, class_idx in enumerate(model.classes_): + matching = np.flatnonzero(labels == class_idx) + if len(matching) == 1: + probs[:, matching[0]] = raw_probs[:, source_idx] + return probs + + +def _log_model(model: Pipeline) -> None: + out_path = Path("model.joblib") + joblib.dump(model, out_path) + mlflow.log_artifact(str(out_path), artifact_path="model") + + +if __name__ == "__main__": + main() From a2fe451b122ac4a856557f3f31691121b05f9da8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 20:53:54 +0200 Subject: [PATCH 060/107] feat: polish and add two distinct submission scripts --- configs/ml/lbfgs_logistic_regression.yaml | 2 +- ml/sklearn_linear.py | 80 ++++++++++++++----- ...linear.py => submit_train_linear_probe.py} | 9 ++- scripts/submit_train_logistic_regression.py | 24 ++++++ 4 files changed, 94 insertions(+), 21 deletions(-) rename scripts/{submit_train_linear.py => submit_train_linear_probe.py} (61%) create mode 100644 scripts/submit_train_logistic_regression.py diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml index c6ef9997..43e1bdd7 100644 --- a/configs/ml/lbfgs_logistic_regression.yaml +++ b/configs/ml/lbfgs_logistic_regression.yaml @@ -45,7 +45,7 @@ model: mlflow_artifact_path: lbfgs_logistic_regression metadata: - run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} + run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} C=${model.C} std=${model.standardize} description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata 
${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py index f8fb5368..08ea7546 100644 --- a/ml/sklearn_linear.py +++ b/ml/sklearn_linear.py @@ -3,8 +3,8 @@ from typing import Any import hydra -import joblib import mlflow +import mlflow.sklearn import numpy as np import pandas as pd from matplotlib import pyplot as plt @@ -42,6 +42,8 @@ def run(config: DictConfig) -> None: np.random.seed(config.seed) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) data.setup("fit") + if "test" in data.datasets: + data.setup("test") x_train = data.train.embeddings y_train = data.train.labels @@ -53,26 +55,42 @@ def run(config: DictConfig) -> None: ] model = _build_model(config) model.fit(x_train, y_train) + _log_convergence(model, config) _log_split_metrics( model, x_val, y_val, data.val.slide_ids, class_names, "validation" ) + if hasattr(data, "test"): + _log_split_metrics( + model, + data.test.embeddings, + data.test.labels, + data.test.slide_ids, + class_names, + "test", + ) _log_model(model) - mlflow.log_params( - { - "model_type": "sklearn_logistic_regression", - "solver": config.model.solver, - "penalty": config.model.penalty, - "C": config.model.C, - "max_iter": config.model.max_iter, - "tol": config.model.tol, - "class_weight": config.model.class_weight, - "standardize": config.model.standardize, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - ) + params = { + "model_type": "sklearn_logistic_regression", + "solver": config.model.solver, + "penalty": config.model.penalty, + "C": config.model.C, + "max_iter": config.model.max_iter, + "tol": config.model.tol, + "class_weight": config.model.class_weight, + "standardize": config.model.standardize, + "val_fold": config.val_fold, + "kfold_strategy": config.kfold_strategy, + "embedding_run_id": config.embedding_run_id, + "kfold_run_id": config.kfold_run_id, + "filter_tiles_run_id": 
config.filter_tiles_run_id, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + if hasattr(data, "test"): + params["test_tiles"] = len(data.test.labels) + mlflow.log_params(params) def _build_model(config: DictConfig) -> Pipeline: @@ -152,6 +170,27 @@ def _log_split_metrics( out_path = Path(f"{split}_predictions.parquet") predictions.to_parquet(out_path, index=False) mlflow.log_artifact(str(out_path), artifact_path="predictions") + _log_per_slide_accuracy(predictions, split) + + +def _log_per_slide_accuracy(predictions: pd.DataFrame, split: str) -> None: + rows = [] + for slide_id, slide_df in predictions.groupby("slide_id"): + rows.append( + { + "slide_id": slide_id, + "tile_accuracy": float((slide_df["pred"] == slide_df["target"]).mean()), + "n_tiles": len(slide_df), + } + ) + if not rows: + return + + per_slide = pd.DataFrame(rows) + mlflow.log_metric(f"{split}/slide_acc_mean", per_slide["tile_accuracy"].mean()) + mlflow.log_metric(f"{split}/slide_acc_median", per_slide["tile_accuracy"].median()) + mlflow.log_metric(f"{split}/slide_acc_min", per_slide["tile_accuracy"].min()) + mlflow.log_table(per_slide, artifact_file=f"per_slide/{split}_tile_accuracy.json") def _predict_proba_for_all_classes( @@ -166,10 +205,15 @@ def _predict_proba_for_all_classes( return probs +def _log_convergence(model: Pipeline, config: DictConfig) -> None: + classifier = model.named_steps["classifier"] + n_iter = int(classifier.n_iter_.max()) + mlflow.log_metric("n_iter", n_iter) + mlflow.log_param("converged", n_iter < config.model.max_iter) + + def _log_model(model: Pipeline) -> None: - out_path = Path("model.joblib") - joblib.dump(model, out_path) - mlflow.log_artifact(str(out_path), artifact_path="model") + mlflow.sklearn.log_model(model, artifact_path="model") if __name__ == "__main__": diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear_probe.py similarity index 61% rename from scripts/submit_train_linear.py rename to 
scripts/submit_train_linear_probe.py index 93cf6868..1c7b6090 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear_probe.py @@ -2,7 +2,7 @@ submit_job( - job_name="tissue-classification-train-linear", + job_name="tissue-classification-train-linear-probe", username=..., cpu=8, memory="64Gi", @@ -12,7 +12,12 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", + ( + "uv run python -m ml " + "+experiment=ml/..." + "val_fold=0,1,2,3,4 " + "--multirun" + ), ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py new file mode 100644 index 00000000..858ec8cb --- /dev/null +++ b/scripts/submit_train_logistic_regression.py @@ -0,0 +1,24 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-logistic-regression", + username=..., + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + ( + "uv run python -m ml " + "+experiment=ml/..." 
+ "val_fold=0,1,2,3,4 " + "model.C=0.001,0.01,0.1,1,10,100 " + "--multirun" + ), + ], + storage=[storage.secure.PROJECTS], +) From 31ecf6d3de399581e8bbf28a7752e38515b07433 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 21:03:31 +0200 Subject: [PATCH 061/107] fix: submission scripts --- scripts/submit_train_linear_probe.py | 9 ++------- scripts/submit_train_logistic_regression.py | 8 +------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 1c7b6090..93cf6868 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -2,7 +2,7 @@ submit_job( - job_name="tissue-classification-train-linear-probe", + job_name="tissue-classification-train-linear", username=..., cpu=8, memory="64Gi", @@ -12,12 +12,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - ( - "uv run python -m ml " - "+experiment=ml/..." - "val_fold=0,1,2,3,4 " - "--multirun" - ), + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py index 858ec8cb..b0432523 100644 --- a/scripts/submit_train_logistic_regression.py +++ b/scripts/submit_train_logistic_regression.py @@ -12,13 +12,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - ( - "uv run python -m ml " - "+experiment=ml/..." - "val_fold=0,1,2,3,4 " - "model.C=0.001,0.01,0.1,1,10,100 " - "--multirun" - ), + "uv run python -m ml +experiment=ml/... 
val_fold=0,1,2,3,4 model.C=0.001,0.01,0.1,1,10,100 --multirun", ], storage=[storage.secure.PROJECTS], ) From ff8d0bf45c2ff2aa1f0a6c0ef4ff024bd5e14500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 21:33:59 +0200 Subject: [PATCH 062/107] feat: implement knn --- .../ml/knn_stratified_group_kfold.yaml | 8 ++ .../experiment/ml/knn_stratified_kfold.yaml | 8 ++ configs/ml/knn.yaml | 62 ++++++++++ ml/__main__.py | 6 + ml/sklearn_knn.py | 111 ++++++++++++++++++ scripts/submit_train_knn.py | 18 +++ 6 files changed, 213 insertions(+) create mode 100644 configs/experiment/ml/knn_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/knn_stratified_kfold.yaml create mode 100644 configs/ml/knn.yaml create mode 100644 ml/sklearn_knn.py create mode 100644 scripts/submit_train_knn.py diff --git a/configs/experiment/ml/knn_stratified_group_kfold.yaml b/configs/experiment/ml/knn_stratified_group_kfold.yaml new file mode 100644 index 00000000..3c9cafe8 --- /dev/null +++ b/configs/experiment/ml/knn_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/knn + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/knn_stratified_kfold.yaml b/configs/experiment/ml/knn_stratified_kfold.yaml new file mode 100644 index 00000000..87876c23 --- /dev/null +++ b/configs/experiment/ml/knn_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/knn + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/knn.yaml b/configs/ml/knn.yaml new file mode 100644 index 00000000..b2a65613 --- /dev/null +++ b/configs/ml/knn.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/data: embedding + - _self_ + +mode: fit +runner: 
sklearn_knn + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +val_fold: 0 + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +model: + n_neighbors: 25 + weights: distance + metric: cosine + algorithm: brute + n_jobs: -1 + standardize: false + log_model: false + +mlflow_artifact_path: knn + +metadata: + run_name: kNN ${dataset.name} k=${model.n_neighbors} ${model.weights} ${model.metric} ${kfold_strategy} fold=${val_fold} + description: "kNN classifier over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." 
+ hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + n_neighbors: ${model.n_neighbors} + weights: ${model.weights} + metric: ${model.metric} + algorithm: ${model.algorithm} + n_jobs: ${model.n_jobs} + standardize: ${model.standardize} + log_model: ${model.log_model} diff --git a/ml/__main__.py b/ml/__main__.py index 620c882f..3c4cad04 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -26,6 +26,12 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: run_sklearn_linear(config) mlflow.end_run() return + if config.get("runner") == "sklearn_knn": + from ml.sklearn_knn import run as run_sklearn_knn + + run_sklearn_knn(config) + mlflow.end_run() + return seed_everything(config.seed, workers=True) diff --git a/ml/sklearn_knn.py b/ml/sklearn_knn.py new file mode 100644 index 00000000..e955dd8a --- /dev/null +++ b/ml/sklearn_knn.py @@ -0,0 +1,111 @@ +from random import randint +from typing import Any + +import hydra +import mlflow +import numpy as np +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog +from rationai.mlkit.lightning.loggers import MLFlowLogger +from sklearn.neighbors import KNeighborsClassifier +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + +from ml.data import DataModule +from ml.sklearn_linear import _log_model, _log_split_metrics + + +if not OmegaConf.has_resolver("random_seed"): + OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True + ) +if not OmegaConf.has_resolver("len"): + OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + run(config) + mlflow.end_run() + + +def 
run(config: DictConfig) -> None: + if config.mode != "fit": + raise ValueError("sklearn_knn currently supports only mode='fit'") + + np.random.seed(config.seed) + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + data.setup("fit") + if "test" in data.datasets: + data.setup("test") + + x_train = data.train.embeddings + y_train = data.train.labels + x_val = data.val.embeddings + y_val = data.val.labels + + class_names = [ + n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) + ] + model = _build_model(config) + model.fit(x_train, y_train) + + _log_split_metrics( + model, x_val, y_val, data.val.slide_ids, class_names, "validation" + ) + if "test" in data.datasets: + _log_split_metrics( + model, + data.test.embeddings, + data.test.labels, + data.test.slide_ids, + class_names, + "test", + ) + if config.model.get("log_model", False): + _log_model(model) + + params = { + "model_type": "sklearn_knn", + "n_neighbors": config.model.n_neighbors, + "weights": config.model.weights, + "metric": config.model.metric, + "algorithm": config.model.algorithm, + "n_jobs": config.model.n_jobs, + "standardize": config.model.standardize, + "log_model": config.model.get("log_model", False), + "val_fold": config.val_fold, + "kfold_strategy": config.kfold_strategy, + "embedding_run_id": config.embedding_run_id, + "kfold_run_id": config.kfold_run_id, + "filter_tiles_run_id": config.filter_tiles_run_id, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + if "test" in data.datasets: + params["test_tiles"] = len(data.test.labels) + mlflow.log_params(params) + + +def _build_model(config: DictConfig) -> Pipeline: + steps: list[tuple[str, Any]] = [] + if config.model.standardize: + steps.append(("scaler", StandardScaler())) + steps.append( + ( + "classifier", + KNeighborsClassifier( + n_neighbors=config.model.n_neighbors, + weights=config.model.weights, + metric=config.model.metric, + algorithm=config.model.algorithm, + 
n_jobs=config.model.n_jobs, + ), + ) + ) + return Pipeline(steps) + + +if __name__ == "__main__": + main() diff --git a/scripts/submit_train_knn.py b/scripts/submit_train_knn.py new file mode 100644 index 00000000..5959cd2b --- /dev/null +++ b/scripts/submit_train_knn.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-knn", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=ml/knn_stratified_group_kfold val_fold=0,1,2,3,4 model.n_neighbors=1,3,5,11,25,51,101 model.weights=uniform,distance model.metric=cosine,euclidean --multirun", + ], + storage=[storage.secure.PROJECTS], +) From 1f87154bd87ebb2da57ef3811096126e32b08c7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:14:45 +0200 Subject: [PATCH 063/107] refactor: focus on convergence --- configs/ml/trainer/default.yaml | 5 +++-- ml/meta_arch.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index a465a5cf..d3b21718 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -1,7 +1,7 @@ # @package _global_ trainer: - max_epochs: 50 + max_epochs: 500 accelerator: auto devices: auto precision: 32 @@ -13,7 +13,8 @@ trainer: _target_: lightning.pytorch.callbacks.EarlyStopping monitor: validation/loss mode: min - patience: 2 + patience: 1 + min_delta: 1.0e-4 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: validation/loss diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 07d14311..e7c4c52f 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -7,6 +7,7 @@ import pandas as pd import torch from lightning import 
LightningModule +from lightning.pytorch.utilities import grad_norm from matplotlib import pyplot as plt from matplotlib.figure import Figure from torch import Tensor, nn @@ -102,6 +103,16 @@ def training_step(self, batch: Input, batch_idx: int) -> Tensor: self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) return loss + def on_before_optimizer_step(self, optimizer: Optimizer) -> None: + norms = grad_norm(self, norm_type=2) + self.log( + "train/grad_norm", + norms["grad_2.0_norm_total"], + on_step=False, + on_epoch=True, + prog_bar=True, + ) + def validation_step(self, batch: Input, batch_idx: int) -> None: inputs, targets, _ = batch outputs = self(inputs) From 7039307d3cd995a1405fbf835f3d19168e5157ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:28:57 +0200 Subject: [PATCH 064/107] Remove kNN sklearn baseline --- .../ml/knn_stratified_group_kfold.yaml | 8 -- .../experiment/ml/knn_stratified_kfold.yaml | 8 -- configs/ml/knn.yaml | 62 ---------- ml/__main__.py | 7 -- ml/sklearn_knn.py | 111 ------------------ scripts/submit_train_knn.py | 18 --- 6 files changed, 214 deletions(-) delete mode 100644 configs/experiment/ml/knn_stratified_group_kfold.yaml delete mode 100644 configs/experiment/ml/knn_stratified_kfold.yaml delete mode 100644 configs/ml/knn.yaml delete mode 100644 ml/sklearn_knn.py delete mode 100644 scripts/submit_train_knn.py diff --git a/configs/experiment/ml/knn_stratified_group_kfold.yaml b/configs/experiment/ml/knn_stratified_group_kfold.yaml deleted file mode 100644 index 3c9cafe8..00000000 --- a/configs/experiment/ml/knn_stratified_group_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/knn - - _self_ - -kfold_strategy: stratified_group -kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/knn_stratified_kfold.yaml b/configs/experiment/ml/knn_stratified_kfold.yaml deleted 
file mode 100644 index 87876c23..00000000 --- a/configs/experiment/ml/knn_stratified_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/knn - - _self_ - -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/knn.yaml b/configs/ml/knn.yaml deleted file mode 100644 index b2a65613..00000000 --- a/configs/ml/knn.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - /class_mapping: collapse_alterations_to_other - - /ml/data: embedding - - _self_ - -mode: fit -runner: sklearn_knn - -embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} -filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - -train_embedding_uri: runs:/${embedding_run_id}/train/tiles -test_embedding_uri: runs:/${embedding_run_id}/test/tiles -train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet -test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet - -val_fold: 0 - -tissue_prop_min: 0.2 -thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.6 - Epithelium: 0.2 - Muscle: 0.5 - Other: 0.5 - -model: - n_neighbors: 25 - weights: distance - metric: cosine - algorithm: brute - n_jobs: -1 - standardize: false - log_model: false - -mlflow_artifact_path: knn - -metadata: - run_name: kNN ${dataset.name} k=${model.n_neighbors} ${model.weights} ${model.metric} ${kfold_strategy} fold=${val_fold} - description: "kNN classifier over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." 
- hyperparams: - embedding_run_id: ${embedding_run_id} - kfold_strategy: ${kfold_strategy} - kfold_run_id: ${kfold_run_id} - filter_tiles_run_id: ${filter_tiles_run_id} - val_fold: ${val_fold} - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} - n_neighbors: ${model.n_neighbors} - weights: ${model.weights} - metric: ${model.metric} - algorithm: ${model.algorithm} - n_jobs: ${model.n_jobs} - standardize: ${model.standardize} - log_model: ${model.log_model} diff --git a/ml/__main__.py b/ml/__main__.py index 3c4cad04..ae5a82e9 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -26,13 +26,6 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: run_sklearn_linear(config) mlflow.end_run() return - if config.get("runner") == "sklearn_knn": - from ml.sklearn_knn import run as run_sklearn_knn - - run_sklearn_knn(config) - mlflow.end_run() - return - seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_knn.py b/ml/sklearn_knn.py deleted file mode 100644 index e955dd8a..00000000 --- a/ml/sklearn_knn.py +++ /dev/null @@ -1,111 +0,0 @@ -from random import randint -from typing import Any - -import hydra -import mlflow -import numpy as np -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog -from rationai.mlkit.lightning.loggers import MLFlowLogger -from sklearn.neighbors import KNeighborsClassifier -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from ml.data import DataModule -from ml.sklearn_linear import _log_model, _log_split_metrics - - -if not OmegaConf.has_resolver("random_seed"): - OmegaConf.register_new_resolver( - "random_seed", lambda: randint(0, 2**31), use_cache=True - ) -if not OmegaConf.has_resolver("len"): - OmegaConf.register_new_resolver("len", lambda x: len(x)) - - -@hydra.main(config_path="../configs", config_name="ml", version_base=None) -@autolog -def main(config: 
DictConfig, logger: MLFlowLogger) -> None: - run(config) - mlflow.end_run() - - -def run(config: DictConfig) -> None: - if config.mode != "fit": - raise ValueError("sklearn_knn currently supports only mode='fit'") - - np.random.seed(config.seed) - data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) - data.setup("fit") - if "test" in data.datasets: - data.setup("test") - - x_train = data.train.embeddings - y_train = data.train.labels - x_val = data.val.embeddings - y_val = data.val.labels - - class_names = [ - n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) - ] - model = _build_model(config) - model.fit(x_train, y_train) - - _log_split_metrics( - model, x_val, y_val, data.val.slide_ids, class_names, "validation" - ) - if "test" in data.datasets: - _log_split_metrics( - model, - data.test.embeddings, - data.test.labels, - data.test.slide_ids, - class_names, - "test", - ) - if config.model.get("log_model", False): - _log_model(model) - - params = { - "model_type": "sklearn_knn", - "n_neighbors": config.model.n_neighbors, - "weights": config.model.weights, - "metric": config.model.metric, - "algorithm": config.model.algorithm, - "n_jobs": config.model.n_jobs, - "standardize": config.model.standardize, - "log_model": config.model.get("log_model", False), - "val_fold": config.val_fold, - "kfold_strategy": config.kfold_strategy, - "embedding_run_id": config.embedding_run_id, - "kfold_run_id": config.kfold_run_id, - "filter_tiles_run_id": config.filter_tiles_run_id, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - if "test" in data.datasets: - params["test_tiles"] = len(data.test.labels) - mlflow.log_params(params) - - -def _build_model(config: DictConfig) -> Pipeline: - steps: list[tuple[str, Any]] = [] - if config.model.standardize: - steps.append(("scaler", StandardScaler())) - steps.append( - ( - "classifier", - KNeighborsClassifier( - n_neighbors=config.model.n_neighbors, - 
weights=config.model.weights, - metric=config.model.metric, - algorithm=config.model.algorithm, - n_jobs=config.model.n_jobs, - ), - ) - ) - return Pipeline(steps) - - -if __name__ == "__main__": - main() diff --git a/scripts/submit_train_knn.py b/scripts/submit_train_knn.py deleted file mode 100644 index 5959cd2b..00000000 --- a/scripts/submit_train_knn.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-knn", - username="vcifka", - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/knn_stratified_group_kfold val_fold=0,1,2,3,4 model.n_neighbors=1,3,5,11,25,51,101 model.weights=uniform,distance model.metric=cosine,euclidean --multirun", - ], - storage=[storage.secure.PROJECTS], -) From 729eccd54f6c1b38744eddf27a35cf6fe4463457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:44:32 +0200 Subject: [PATCH 065/107] fix: change monitor to focus on train losss --- configs/ml/trainer/default.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index d3b21718..8615025e 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,16 +11,16 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/loss + monitor: train/loss_epoch mode: min patience: 1 min_delta: 1.0e-4 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/loss + monitor: train/loss_epoch mode: min save_top_k: 1 - filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}" auto_insert_metric_name: false 
lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From d3ed2ed642159335aa49fa9d7770d72893120126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 00:20:08 +0200 Subject: [PATCH 066/107] feat: add run name --- configs/ml/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml index 5606f0a5..821d786e 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -35,7 +35,7 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} wd=${model.weight_decay} description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} From e9fd559e18ab6f4eb0034d8a4ada68505f6fb3c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 10:46:48 +0200 Subject: [PATCH 067/107] chore: remove logistic regression --- ...tic_regression_stratified_group_kfold.yaml | 8 - ..._logistic_regression_stratified_kfold.yaml | 8 - configs/ml/lbfgs_logistic_regression.yaml | 64 ----- ml/__main__.py | 6 - ml/sklearn_linear.py | 220 ------------------ pyproject.toml | 1 - scripts/submit_train_linear_probe.py | 6 +- scripts/submit_train_logistic_regression.py | 18 -- uv.lock | 2 - 9 files changed, 3 insertions(+), 330 deletions(-) delete mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml delete mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml delete mode 100644 configs/ml/lbfgs_logistic_regression.yaml delete mode 100644 ml/sklearn_linear.py delete mode 100644 
scripts/submit_train_logistic_regression.py diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml deleted file mode 100644 index 49c954e0..00000000 --- a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/lbfgs_logistic_regression - - _self_ - -kfold_strategy: stratified_group -kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml deleted file mode 100644 index 3d15556d..00000000 --- a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/lbfgs_logistic_regression - - _self_ - -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml deleted file mode 100644 index 43e1bdd7..00000000 --- a/configs/ml/lbfgs_logistic_regression.yaml +++ /dev/null @@ -1,64 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - /class_mapping: collapse_alterations_to_other - - /ml/data: embedding - - _self_ - -mode: fit -runner: sklearn_linear - -embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} -filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - -train_embedding_uri: runs:/${embedding_run_id}/train/tiles -test_embedding_uri: runs:/${embedding_run_id}/test/tiles -train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet -test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet - -val_fold: 0 - -tissue_prop_min: 0.2 
-thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.6 - Epithelium: 0.2 - Muscle: 0.5 - Other: 0.5 - -model: - solver: lbfgs - penalty: l2 - C: 1.0 - class_weight: balanced - max_iter: 1000 - tol: 1.0e-4 - n_jobs: null - verbose: 0 - standardize: true - -mlflow_artifact_path: lbfgs_logistic_regression - -metadata: - run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} C=${model.C} std=${model.standardize} - description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." - hyperparams: - embedding_run_id: ${embedding_run_id} - kfold_strategy: ${kfold_strategy} - kfold_run_id: ${kfold_run_id} - filter_tiles_run_id: ${filter_tiles_run_id} - val_fold: ${val_fold} - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} - solver: ${model.solver} - penalty: ${model.penalty} - C: ${model.C} - class_weight: ${model.class_weight} - max_iter: ${model.max_iter} - tol: ${model.tol} - standardize: ${model.standardize} diff --git a/ml/__main__.py b/ml/__main__.py index ae5a82e9..318c37c4 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -20,12 +20,6 @@ @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: - if config.get("runner") == "sklearn_linear": - from ml.sklearn_linear import run as run_sklearn_linear - - run_sklearn_linear(config) - mlflow.end_run() - return seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py deleted file mode 100644 index 08ea7546..00000000 --- a/ml/sklearn_linear.py +++ /dev/null @@ -1,220 +0,0 @@ -from pathlib import Path -from random import randint -from typing import Any - -import hydra -import mlflow -import mlflow.sklearn -import numpy as np -import pandas as pd -from 
matplotlib import pyplot as plt -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog -from rationai.mlkit.lightning.loggers import MLFlowLogger -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import confusion_matrix, f1_score, recall_score -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from ml.data import DataModule -from ml.meta_arch import _confmat_figure - - -if not OmegaConf.has_resolver("random_seed"): - OmegaConf.register_new_resolver( - "random_seed", lambda: randint(0, 2**31), use_cache=True - ) -if not OmegaConf.has_resolver("len"): - OmegaConf.register_new_resolver("len", lambda x: len(x)) - - -@hydra.main(config_path="../configs", config_name="ml", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - run(config) - mlflow.end_run() - - -def run(config: DictConfig) -> None: - if config.mode != "fit": - raise ValueError("sklearn_linear currently supports only mode='fit'") - - np.random.seed(config.seed) - data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) - data.setup("fit") - if "test" in data.datasets: - data.setup("test") - - x_train = data.train.embeddings - y_train = data.train.labels - x_val = data.val.embeddings - y_val = data.val.labels - - class_names = [ - n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) - ] - model = _build_model(config) - model.fit(x_train, y_train) - _log_convergence(model, config) - - _log_split_metrics( - model, x_val, y_val, data.val.slide_ids, class_names, "validation" - ) - if hasattr(data, "test"): - _log_split_metrics( - model, - data.test.embeddings, - data.test.labels, - data.test.slide_ids, - class_names, - "test", - ) - _log_model(model) - - params = { - "model_type": "sklearn_logistic_regression", - "solver": config.model.solver, - "penalty": config.model.penalty, - "C": config.model.C, - "max_iter": config.model.max_iter, 
- "tol": config.model.tol, - "class_weight": config.model.class_weight, - "standardize": config.model.standardize, - "val_fold": config.val_fold, - "kfold_strategy": config.kfold_strategy, - "embedding_run_id": config.embedding_run_id, - "kfold_run_id": config.kfold_run_id, - "filter_tiles_run_id": config.filter_tiles_run_id, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - if hasattr(data, "test"): - params["test_tiles"] = len(data.test.labels) - mlflow.log_params(params) - - -def _build_model(config: DictConfig) -> Pipeline: - steps: list[tuple[str, Any]] = [] - if config.model.standardize: - steps.append(("scaler", StandardScaler())) - steps.append( - ( - "classifier", - LogisticRegression( - C=config.model.C, - class_weight=config.model.class_weight, - max_iter=config.model.max_iter, - n_jobs=config.model.n_jobs, - penalty=config.model.penalty, - random_state=config.seed, - solver=config.model.solver, - tol=config.model.tol, - verbose=config.model.verbose, - ), - ) - ) - return Pipeline(steps) - - -def _log_split_metrics( - model: Pipeline, - inputs: np.ndarray, - targets: np.ndarray, - slide_ids: np.ndarray, - class_names: list[str], - split: str, -) -> None: - labels = np.arange(len(class_names)) - preds = model.predict(inputs) - probs = _predict_proba_for_all_classes(model, inputs, labels) - - mlflow.log_metric( - f"{split}/acc_macro", - recall_score(targets, preds, labels=labels, average="macro", zero_division=0), - ) - mlflow.log_metric( - f"{split}/f1_macro", - f1_score(targets, preds, average="macro", zero_division=0), - ) - - per_class_acc = recall_score( - targets, preds, labels=labels, average=None, zero_division=0 - ) - per_class_f1 = f1_score( - targets, preds, labels=labels, average=None, zero_division=0 - ) - for cls_name, acc, f1 in zip( - class_names, per_class_acc.tolist(), per_class_f1.tolist(), strict=True - ): - mlflow.log_metric(f"{split}/acc_per_class/{cls_name}", acc) - 
mlflow.log_metric(f"{split}/f1_per_class/{cls_name}", f1) - - matrix = confusion_matrix(targets, preds, labels=labels) - fig = _confmat_figure(matrix, class_names, title=f"{split} confmat") - try: - mlflow.log_figure(fig, artifact_file=f"confusion_matrix/{split}.png") - finally: - plt.close(fig) - - prob_columns = [f"prob_{c}" for c in class_names] - predictions = pd.DataFrame( - { - "slide_id": slide_ids, - "target": targets, - "pred": preds, - } - ) - predictions = pd.concat( - [predictions, pd.DataFrame(probs, columns=prob_columns)], axis=1 - ) - out_path = Path(f"{split}_predictions.parquet") - predictions.to_parquet(out_path, index=False) - mlflow.log_artifact(str(out_path), artifact_path="predictions") - _log_per_slide_accuracy(predictions, split) - - -def _log_per_slide_accuracy(predictions: pd.DataFrame, split: str) -> None: - rows = [] - for slide_id, slide_df in predictions.groupby("slide_id"): - rows.append( - { - "slide_id": slide_id, - "tile_accuracy": float((slide_df["pred"] == slide_df["target"]).mean()), - "n_tiles": len(slide_df), - } - ) - if not rows: - return - - per_slide = pd.DataFrame(rows) - mlflow.log_metric(f"{split}/slide_acc_mean", per_slide["tile_accuracy"].mean()) - mlflow.log_metric(f"{split}/slide_acc_median", per_slide["tile_accuracy"].median()) - mlflow.log_metric(f"{split}/slide_acc_min", per_slide["tile_accuracy"].min()) - mlflow.log_table(per_slide, artifact_file=f"per_slide/{split}_tile_accuracy.json") - - -def _predict_proba_for_all_classes( - model: Pipeline, inputs: np.ndarray, labels: np.ndarray -) -> np.ndarray: - raw_probs = model.predict_proba(inputs) - probs = np.zeros((len(inputs), len(labels)), dtype=raw_probs.dtype) - for source_idx, class_idx in enumerate(model.classes_): - matching = np.flatnonzero(labels == class_idx) - if len(matching) == 1: - probs[:, matching[0]] = raw_probs[:, source_idx] - return probs - - -def _log_convergence(model: Pipeline, config: DictConfig) -> None: - classifier = 
model.named_steps["classifier"] - n_iter = int(classifier.n_iter_.max()) - mlflow.log_metric("n_iter", n_iter) - mlflow.log_param("converged", n_iter < config.model.max_iter) - - -def _log_model(model: Pipeline) -> None: - mlflow.sklearn.log_model(model, artifact_path="model") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 0cff3865..450c4922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "ratiopath>=1.2.0", "pyarrow>=19.0.1", "datasets>=4.0.0", - "scikit-learn>=1.8.0", "numpy>=2.3.5", "rationai-tiling>=1.1.1", "tifffile>=2025.12.20", diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 93cf6868..3f7ecb1a 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username=..., + username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... 
val_fold=0,1,2,3,4 --multirun", + "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py deleted file mode 100644 index b0432523..00000000 --- a/scripts/submit_train_logistic_regression.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-logistic-regression", - username=..., - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/... val_fold=0,1,2,3,4 model.C=0.001,0.01,0.1,1,10,100 --multirun", - ], - storage=[storage.secure.PROJECTS], -) diff --git a/uv.lock b/uv.lock index c4ad7300..1e1ee3aa 100644 --- a/uv.lock +++ b/uv.lock @@ -2306,7 +2306,6 @@ dependencies = [ { name = "rationai-tiling" }, { name = "ratiopath" }, { name = "ray" }, - { name = "scikit-learn" }, { name = "tifffile" }, { name = "timm" }, { name = "torch" }, @@ -2342,7 +2341,6 @@ requires-dist = [ { name = "rationai-tiling", git = "https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/tiling.git" }, { name = "ratiopath", specifier = ">=1.2.0" }, { name = "ray", specifier = ">=2.51.1" }, - { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "tifffile", specifier = ">=2025.12.20" }, { name = "timm", specifier = ">=1.0.0" }, { name = "torch", specifier = ">=2.0.0" }, From 6dadbd75aa3cda770c14f545266c746226236177 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 11:16:29 +0200 Subject: [PATCH 068/107] feat: implement lbfgs --- ...assifier_lbfgs_stratified_group_kfold.yaml | 26 ++++ ...ear_classifier_lbfgs_stratified_kfold.yaml | 26 ++++ 
configs/ml/data/embedding.yaml | 2 + configs/ml/linear_classifier.yaml | 6 +- configs/ml/model/linear_classifier.yaml | 10 ++ ml/data/data_module.py | 13 +- ml/meta_arch.py | 144 ++++++++++++++++++ 7 files changed, 223 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml new file mode 100644 index 00000000..85959862 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_group_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml new file mode 100644 index 00000000..bd3c10b3 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git 
a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml index 597e012e..40ff4b71 100644 --- a/configs/ml/data/embedding.yaml +++ b/configs/ml/data/embedding.yaml @@ -3,6 +3,8 @@ data: batch_size: 1024 num_workers: 4 + train_shuffle: true + train_drop_last: true train: _target_: ml.data.datasets.EmbeddingTilesDataset diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml index 821d786e..d3393372 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -35,7 +35,7 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} wd=${model.weight_decay} + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} @@ -45,6 +45,10 @@ metadata: val_fold: ${val_fold} tissue_prop_min: ${tissue_prop_min} thresholds: ${thresholds} + optimizer: ${model.optimizer} learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} + lbfgs: ${model.lbfgs} batch_size: ${data.batch_size} + train_shuffle: ${data.train_shuffle} + train_drop_last: ${data.train_drop_last} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index fc9bfedd..4b4d9e83 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,5 +11,15 @@ model: class_indices: ${class_indices} + optimizer: adamw learning_rate: 1.0e-4 weight_decay: 0.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 7302be54..bfac1182 
100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -16,11 +16,18 @@ class DataModule(LightningDataModule): """ def __init__( - self, batch_size: int, num_workers: int = 0, **datasets: DictConfig + self, + batch_size: int, + num_workers: int = 0, + train_shuffle: bool = True, + train_drop_last: bool = True, + **datasets: DictConfig, ) -> None: super().__init__() self.batch_size = batch_size self.num_workers = num_workers + self.train_shuffle = train_shuffle + self.train_drop_last = train_drop_last self.datasets = datasets def setup(self, stage: str) -> None: @@ -42,8 +49,8 @@ def train_dataloader(self) -> Iterable[Input]: return DataLoader( self.train, batch_size=self.batch_size, - shuffle=True, - drop_last=True, + shuffle=self.train_shuffle, + drop_last=self.train_drop_last, num_workers=self.num_workers, persistent_workers=self.num_workers > 0, ) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index e7c4c52f..73cd3f58 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -37,13 +37,21 @@ def __init__( class_indices: dict[str, int], learning_rate: float = 1e-3, weight_decay: float = 0.0, + optimizer: str = "adamw", + lbfgs: dict[str, Any] | None = None, ) -> None: super().__init__() self.save_hyperparameters(ignore=["backbone", "decode_head"]) + if optimizer not in {"adamw", "lbfgs"}: + raise ValueError(f"Unsupported optimizer {optimizer!r}") + if optimizer == "lbfgs": + self.automatic_optimization = False + self.backbone = backbone self.decode_head = decode_head self.criterion: nn.Module + self._lbfgs_batches: list[tuple[Tensor, Tensor]] = [] self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) @@ -82,6 +90,8 @@ def setup(self, stage: str) -> None: if stage == "fit": datamodule = cast("Any", self.trainer).datamodule labels = datamodule.train.labels + if self.hparams["optimizer"] == "lbfgs": + self._validate_lbfgs_full_batch(datamodule, len(labels)) num_classes = len(self.class_names) counts = np.bincount(labels, 
minlength=num_classes).astype(float) counts = np.maximum(counts, 1.0) @@ -97,6 +107,9 @@ def forward(self, x: Tensor) -> Outputs: return self.decode_head(features) def training_step(self, batch: Input, batch_idx: int) -> Tensor: + if self.hparams["optimizer"] == "lbfgs": + return self._lbfgs_training_step(batch, batch_idx) + inputs, targets, _ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) @@ -104,6 +117,8 @@ def training_step(self, batch: Input, batch_idx: int) -> Tensor: return loss def on_before_optimizer_step(self, optimizer: Optimizer) -> None: + if self.hparams["optimizer"] == "lbfgs": + return norms = grad_norm(self, norm_type=2) self.log( "train/grad_norm", @@ -163,12 +178,141 @@ def predict_step( } def configure_optimizers(self) -> Optimizer: + if self.hparams["optimizer"] == "lbfgs": + lbfgs = self.hparams.get("lbfgs") or {} + return torch.optim.LBFGS( + self.parameters(), + lr=self.hparams["learning_rate"], + max_iter=lbfgs.get("max_iter", 100), + max_eval=lbfgs.get("max_eval"), + tolerance_grad=lbfgs.get("tolerance_grad", 1.0e-7), + tolerance_change=lbfgs.get("tolerance_change", 1.0e-9), + history_size=lbfgs.get("history_size", 100), + line_search_fn=lbfgs.get("line_search_fn", "strong_wolfe"), + ) + return torch.optim.AdamW( self.parameters(), lr=self.hparams["learning_rate"], weight_decay=self.hparams["weight_decay"], ) + def _lbfgs_training_step(self, batch: Input, batch_idx: int) -> Tensor: + inputs, targets, _ = batch + self._lbfgs_batches.append(self._prepare_lbfgs_batch(inputs, targets)) + lbfgs = self.hparams.get("lbfgs") or {} + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + is_last_batch = batch_idx + 1 == self.trainer.num_training_batches + should_step = len(self._lbfgs_batches) >= accumulation_steps or is_last_batch + if not should_step: + with torch.no_grad(): + return self.criterion(self(inputs), targets) + + optimizer = self.optimizers() + total_samples = sum(targets.numel() for _, targets in 
self._lbfgs_batches) + + def closure() -> Tensor: + optimizer.zero_grad() + loss, _, _ = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS loss: {loss.item()}") + self.manual_backward(loss) + return loss + + step_loss = optimizer.step(closure=closure) + if not isinstance(step_loss, Tensor): + step_loss = torch.as_tensor(step_loss, device=self.device) + + optimizer.zero_grad() + loss, ce_loss, l2_loss = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS post-step loss: {loss.item()}") + self.manual_backward(loss) + grad_norm = self._total_grad_norm() + if grad_norm is not None and not torch.isfinite(grad_norm): + raise FloatingPointError( + f"non-finite LBFGS gradient norm: {grad_norm.item()}" + ) + optimizer.zero_grad() + self._lbfgs_batches.clear() + + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/ce_loss", ce_loss, on_step=True, on_epoch=True) + self.log("train/l2_loss", l2_loss, on_step=True, on_epoch=True) + if grad_norm is not None: + self.log( + "train/grad_norm", + grad_norm, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log("train/lbfgs_step_loss", step_loss.detach(), on_step=True) + return loss + + def _prepare_lbfgs_batch( + self, inputs: Tensor, targets: Tensor + ) -> tuple[Tensor, Tensor]: + lbfgs = self.hparams.get("lbfgs") or {} + if lbfgs.get("accumulate_on_cpu", False): + return inputs.detach().cpu(), targets.detach().cpu() + return inputs, targets + + def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: + lbfgs = self.hparams.get("lbfgs") or {} + batch_size = int(datamodule.batch_size) + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + effective_batch_size = batch_size * accumulation_steps + + if datamodule.train_shuffle: + raise ValueError("LBFGS requires data.train_shuffle=false.") + if datamodule.train_drop_last: 
+ raise ValueError("LBFGS requires data.train_drop_last=false.") + if effective_batch_size < train_size: + raise ValueError( + "LBFGS requires a deterministic full-batch objective. Set " + "data.batch_size >= len(train) or set " + "model.lbfgs.accumulate_batches >= ceil(len(train) / " + "data.batch_size). Current effective batch size is " + f"{effective_batch_size} for {train_size} training samples." + ) + + def _lbfgs_buffered_loss(self, total_samples: int) -> tuple[Tensor, Tensor, Tensor]: + ce_loss = torch.zeros((), device=self.device) + for micro_inputs, micro_targets in self._lbfgs_batches: + micro_inputs = micro_inputs.to(self.device) + micro_targets = micro_targets.to(self.device) + outputs = self(micro_inputs) + weight = micro_targets.numel() / total_samples + ce_loss = ce_loss + self.criterion(outputs, micro_targets) * weight + + l2_loss = self._l2_loss() + return ce_loss + l2_loss, ce_loss, l2_loss + + def _objective_loss(self, outputs: Tensor, targets: Tensor) -> Tensor: + return self.criterion(outputs, targets) + self._l2_loss() + + def _l2_loss(self) -> Tensor: + weight_decay = self.hparams["weight_decay"] + if weight_decay == 0: + return torch.zeros((), device=self.device) + + penalty = torch.zeros((), device=self.device) + for name, param in self.named_parameters(): + if param.requires_grad and name.endswith("weight"): + penalty = penalty + param.square().sum() + return 0.5 * weight_decay * penalty + + def _total_grad_norm(self) -> Tensor | None: + grads = [ + param.grad.detach().norm(2) + for param in self.parameters() + if param.grad is not None + ] + if not grads: + return None + return torch.linalg.vector_norm(torch.stack(grads), ord=2) + def _log_per_class(self, collection: MetricCollection, split: str) -> None: computed = collection.compute() for metric_name, values in computed.items(): From d5d3edd1655aa3a32652e4383175a4d9b5259b4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 
11:32:41 +0200 Subject: [PATCH 069/107] fix: run id --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0ab4e0d5..09f8f4a1 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - stratified_kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5" stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" From 216369965c1efc53a985174e240055047e0c267f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 11:37:43 +0200 Subject: [PATCH 070/107] fix: cache the tiles and embeddings so they do not need to be downloaded twice --- ml/data/datasets/embedding_tiles.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 791a7342..5160ba18 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,6 +5,7 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. 
""" +from functools import cache from pathlib import Path import numpy as np @@ -186,7 +187,11 @@ def _filter_metadata( @staticmethod def _resolve_uri(path_or_uri: str | Path) -> str: - s = str(path_or_uri) - if s.startswith(("mlflow-artifacts:/", "runs:/")): - return download_artifacts(artifact_uri=s) - return s + return EmbeddingTilesDataset._resolve_uri_cached(str(path_or_uri)) + + @staticmethod + @cache + def _resolve_uri_cached(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri From 92868070c23d24c537ffae7ab965ce4f84f60b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 12:43:21 +0200 Subject: [PATCH 071/107] fix: limit num of workers --- .../ml/linear_classifier_lbfgs_stratified_group_kfold.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml index 85959862..4d92561f 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -11,6 +11,7 @@ data: batch_size: 1000000000 train_shuffle: false train_drop_last: false + num_workers: 0 model: optimizer: lbfgs From bb8a043d5dc174768a7aa06ad9ae2de30d6f1bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 17:28:14 +0200 Subject: [PATCH 072/107] fix: support checkpoint test and prediction export --- ml/callbacks/parquet_prediction_writer.py | 17 +++++++++++------ ml/meta_arch.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index 38d86fd4..a6f676bc 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ 
b/ml/callbacks/parquet_prediction_writer.py @@ -32,16 +32,21 @@ def write_on_epoch_end( if trainer.global_rank != 0: return + batches = ( + predictions + if not predictions or isinstance(predictions[0], dict) + else [b for dataloader_preds in predictions for b in dataloader_preds] + ) + slide_ids: list[str] = [] targets: list[int] = [] preds: list[int] = [] probs: list[np.ndarray] = [] - for dataloader_preds in predictions: - for b in dataloader_preds: - slide_ids.extend(b["slide_id"]) - targets.extend(b["target"].tolist()) - preds.extend(b["pred"].tolist()) - probs.append(b["probs"].numpy()) + for b in batches: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) if not slide_ids: return diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 73cd3f58..ab6882e9 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -50,13 +50,13 @@ def __init__( self.backbone = backbone self.decode_head = decode_head - self.criterion: nn.Module self._lbfgs_batches: list[tuple[Tensor, Tensor]] = [] self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) ] num_classes = len(self.class_names) + self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) macro_metrics = MetricCollection( { @@ -208,7 +208,7 @@ def _lbfgs_training_step(self, batch: Input, batch_idx: int) -> Tensor: with torch.no_grad(): return self.criterion(self(inputs), targets) - optimizer = self.optimizers() + optimizer = cast("Any", self.optimizers()) total_samples = sum(targets.numel() for _, targets in self._lbfgs_batches) def closure() -> Tensor: From efddcd6b5f6b11103601e3969e6c3e1a75124520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 00:07:41 +0200 Subject: [PATCH 073/107] Revert "Merge remote-tracking branch 'origin/feature/linear-probe' into feature/ml-test-mode" This reverts commit 
c284d8d64c0556309313f66d72f363774465150e, reversing changes made to 811e21c69b4e2128ec98294bc07d4cfc49eaf559. --- .../collapse_alterations_to_other.yaml | 9 - configs/class_mapping/standard.yaml | 11 - configs/data/dataset.yaml | 2 - ...r_probe_collapse_alterations_to_other.yaml | 11 - configs/ml/linear_probe.yaml | 48 --- ml/PLAN_LINEAR_PROBE.md | 188 ----------- ml/data/embeddings_datamodule.py | 308 ------------------ ml/models/__init__.py | 0 ml/models/linear_probe.py | 121 ------- ml/train.py | 76 ----- scripts/submit_linear_probe.py | 18 - 11 files changed, 792 deletions(-) delete mode 100644 configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml delete mode 100644 configs/ml/linear_probe.yaml delete mode 100644 ml/PLAN_LINEAR_PROBE.md delete mode 100644 ml/data/embeddings_datamodule.py delete mode 100644 ml/models/__init__.py delete mode 100644 ml/models/linear_probe.py delete mode 100644 ml/train.py delete mode 100644 scripts/submit_linear_probe.py diff --git a/configs/class_mapping/collapse_alterations_to_other.yaml b/configs/class_mapping/collapse_alterations_to_other.yaml index 66508f1f..160aad92 100644 --- a/configs/class_mapping/collapse_alterations_to_other.yaml +++ b/configs/class_mapping/collapse_alterations_to_other.yaml @@ -42,12 +42,3 @@ class_indices: Epithelium: 4 Muscle: 5 Other: 6 - -class_names: - - Nerve - - Blood - - Connective-Tissue - - Fat - - Epithelium - - Muscle - - Other diff --git a/configs/class_mapping/standard.yaml b/configs/class_mapping/standard.yaml index a623c14e..39866e3c 100644 --- a/configs/class_mapping/standard.yaml +++ b/configs/class_mapping/standard.yaml @@ -46,14 +46,3 @@ class_indices: Inflammation-Chronic: 6 Necrosis: 7 Neoplastic: 8 - -class_names: - - Nerve - - Blood - - Connective-Tissue - - Fat - - Epithelium - - Muscle - - Inflammation-Chronic - - Necrosis - - Neoplastic diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index b487893b..73926866 100644 --- 
a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -18,8 +18,6 @@ dataset: embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" - embeddings_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" - kfold_run_id: "2e81b0597b614ba8b675e3b34528c1df" exclusions: bad_slides: diff --git a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml b/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml deleted file mode 100644 index 74c089ab..00000000 --- a/configs/experiment/ml/linear_probe_collapse_alterations_to_other.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - /class_mapping: collapse_alterations_to_other - - _self_ - -embeddings_run_id: ${dataset.mlflow_artifacts.embeddings_run_id} -kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} -embed_dim: 2560 -n_folds: 5 diff --git a/configs/ml/linear_probe.yaml b/configs/ml/linear_probe.yaml deleted file mode 100644 index 0e2f9cbb..00000000 --- a/configs/ml/linear_probe.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# @package _global_ - -mode: fit - -embed_dim: 2560 -embeddings_run_id: ??? -kfold_run_id: ??? 
-n_folds: 5 - -trainer: - _target_: rationai.mlkit.lightning.Trainer - max_epochs: 30 - accelerator: auto - devices: 1 - log_every_n_steps: 50 - -data: - _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule - embeddings_run_id: ${embeddings_run_id} - kfold_run_id: ${kfold_run_id} - kfold_artifact_path: kfold_split/kfold_tiles.parquet - class_mapping: ${class_mapping} - class_indices: ${class_indices} - drop_unmapped: true - tissue_prop_min: 0.0 - class_coverage_min: 0.5 - batch_size: 1024 - num_workers: 4 - -model: - _target_: ml.models.linear_probe.LinearProbe - embed_dim: ${embed_dim} - num_classes: ${len:${class_names}} - lr: 1e-3 - weight_decay: 0.0 - class_names: ${class_names} - class_weights: balanced # null | "balanced" | "inverse" | list[float] - -metadata: - run_name: "Linear probe (embed=${embeddings_run_id}, kfold=${kfold_run_id})" - description: Linear probe on cached Virchow2 embeddings, k-fold CV - hyperparams: - embed_dim: ${embed_dim} - lr: ${model.lr} - batch_size: ${data.batch_size} - n_folds: ${n_folds} - tissue_prop_min: ${data.tissue_prop_min} - class_coverage_min: ${data.class_coverage_min} diff --git a/ml/PLAN_LINEAR_PROBE.md b/ml/PLAN_LINEAR_PROBE.md deleted file mode 100644 index 7b9388b3..00000000 --- a/ml/PLAN_LINEAR_PROBE.md +++ /dev/null @@ -1,188 +0,0 @@ -# Linear-Probe Training: Implementation Plan - -Scope of the first PR: **train + k-fold validation only**. Test-set evaluation is a follow-up PR (separate entrypoint, no fold loop, possibly slide-level aggregation). Keeping test out of this PR keeps the held-out set untouched while we tune the probe. - -## 0. Current state (what already exists) - -- `ml/train.py` — Lightning entrypoint, `mode={fit,test}`, instantiates `data` / `model` / `trainer` from Hydra. -- `ml/data/embeddings_datamodule.py` — `EmbeddingsDataModule` that loads `train_dir` / `test_dir` parquet via `datasets.load_dataset`, filters by `fold`, maps `label` → idx. 
-- `ml/models/linear_probe.py` — `nn.Linear` head, CE loss, accuracy + macro-F1 (`torchmetrics`). -- `configs/ml/linear_probe.yaml` — wires the above; assumes a single parquet dir per split with `embedding`, `label`, `fold`, `tissue_prop` already joined. -- `configs/class_mapping/collapse_alterations_to_other.yaml` — **the mapping we will use for training**: 7 classes (Nerve, Blood, Connective-Tissue, Fat, Epithelium, Muscle, Other). Inflammation/necrosis/neoplastic alterations are collapsed into `Other`. `standard.yaml` is the alternate 9-class mapping; not used here. - -## 1. Gaps to close before this works end-to-end - -### 1.1 Embeddings and labels are not in the same parquet -`embed.py` writes `train/tiles/*.parquet` with columns `slide_id, x, y, embedding`. `kfold_split.py` writes one `kfold_tiles.parquet` with columns `slide_id, x, y, label, tissue_prop, fold` (+ `roi_coverage_*`). The datamodule today assumes everything is in one file. → must **join on `(slide_id, x, y)`**. - -**Recommendation:** do the join lazily in `setup()` using `pyarrow` / `duckdb` over the parquet files (no extra preprocessing script, no extra MLflow run). The labels parquet is small enough to fit in RAM; embeddings stay memory-mapped. - -### 1.2 Inputs come from MLflow artifacts, not local disk -Current config hardcodes `${project_path}/embeddings/${embeddings_run_id}/...`. The other scripts (`embed.py`, `kfold_split.py`) consistently use `mlflow.artifacts.download_artifacts(run_id=..., artifact_path=...)`. → datamodule should accept `embeddings_run_id` + `kfold_run_id` and download to a cache dir on `prepare_data()` (single-process hook). - -### 1.3 Raw labels in kfold parquet are not the 9 canonical classes -`kfold_split.py` writes `label = roi_coverage_` argmax → values like `"EPITHELIUM-BB"`, `"NEOPLASTIC-MALIGNANT"`, or `"background"`. The probe expects canonical names (`Epithelium`, `Neoplastic`, …). → apply `class_mapping` (raw → canonical) inside the datamodule. 
Tiles whose raw label isn't covered by the mapping (today: `"background"`) need a policy — see §1.4. - -**Update (post-smoke-run):** the *filtered* tiles parquet — which is what kfold now runs over (commit `64b2000`) — already has the ROI columns collapsed upstream (`roi_coverage_Nerve`, `roi_coverage_Epithelium`, …). So `kfold_split.derive_labels` strips the prefix and writes already-canonical names (`"Nerve"`, `"Epithelium"`, …) plus `"background"` straight into `label`. The raw→canonical lookup built from the YAML's BB-suffixed lists matches **none** of these, so with `drop_unmapped=True` the entire dataset gets dropped (observed: `n_tiles_after_label_map = 0` on a 1.1M-tile run). Fix: extend `_raw_to_canonical` with identity entries for every canonical class (`{c: c for c in class_indices}`). This keeps backward-compat with legacy un-collapsed parquets while letting modern canonical-label parquets pass through unchanged. `"background"` stays unmapped → dropped, which is what we want (no ROI overlap = no label). - -### 1.4 Background and coverage-threshold filtering -**Background**: `collapse_alterations_to_other.yaml` has no Background class. `filter_tiles.py` already drops tiles with zero tissue coverage and zero annotation coverage upstream — so by the time we reach the kfold parquet, `"background"` rows can still appear (a tile can have tissue but no annotation overlap, or vice-versa, depending on the filter logic). Drop any rows whose raw label isn't in the mapping. Config knob: `drop_unmapped: true` (default true). - -**Coverage thresholds (live in the datamodule, not a separate PR/script)**: the upstream filter is the coarse cleaning step (any tissue, any annotation). For training-time experimentation, expose two filters as datamodule knobs and apply them after the join, before the train/val split: - -- `tissue_prop_min: float = 0.0` — drop tiles whose total annotation coverage `tissue_prop` is below the threshold. 
This already exists as a field on the datamodule; keep it. -- `class_coverage_min: float = 0.0` — drop tiles whose **dominant class coverage** (i.e. the `roi_coverage_*` value backing the assigned label, after collapsing per the class mapping) is below the threshold. Forces the label to be "confident" — useful when many tiles are mosaics. - -Both are pure row masks on the labels DataFrame, cheap, and get logged as MLflow params with the run, so threshold sweeps show up cleanly. Rationale for not making this a separate preprocessing PR: these thresholds are experimental knobs you'll sweep alongside LR / weight decay; locking them into a parquet artifact would force a re-preprocessing run for every variant. The fundamental cleaning (any-tissue, any-annotation) stays where it belongs in `filter_tiles.py`. - -To support `class_coverage_min` after class collapsing, the datamodule needs the per-class collapsed coverage. Compute it in `setup()`: for each canonical class C, sum the `roi_coverage_` columns whose raw label maps to C. Then the dominant-class coverage for a tile is `max_C(collapsed_coverage_C)`. The kfold parquet already carries `roi_coverage_*` columns, so no new artifact is needed. - -### 1.5 Config bugs -- `configs/ml/linear_probe.yaml:30` uses `class_mapping.class_names`, which doesn't exist in `standard.yaml`. Either (a) add a `class_names` list to the class-mapping yaml derived from the dict keys, or (b) change the reference. Pick (a) — most readable. -- `${len:...}` resolver — verify it's registered (rationai.mlkit likely does, but confirm by running once). - -### 1.6 K-fold orchestration -User wants **all folds in one MLflow run**. Today `train.py` runs a single fold (`val_fold` param). → wrap fit in a loop over folds, log per-fold metrics under `fold_{i}/...` and write aggregate (`val/acc_mean`, `val/acc_std`, `val/f1_macro_mean`, …) at the end. - -### 1.7 `trainer.test()` after fit -`train.py:21` calls `trainer.test(...)` after `trainer.fit(...)`. 
Remove for this PR (no test in this PR). Re-introduce in the test-PR, in a separate `mode=test` path that doesn't loop folds. - ---- - -## 2. Concrete step-by-step plan - -### Step 1 — Fix `configs/class_mapping/collapse_alterations_to_other.yaml` -Add a derived `class_names` list (so configs that reference `class_mapping.class_names` work) and switch `linear_probe.yaml` to default to this mapping: -```yaml -class_names: - - Nerve - - Blood - - Connective-Tissue - - Fat - - Epithelium - - Muscle - - Other -``` -Apply the same change to `standard.yaml` for consistency, but `linear_probe.yaml`'s `defaults` should point at `collapse_alterations_to_other`. Keep `class_mapping` (canonical→raw list) and `class_indices` as they are. - -### Step 2 — Rewrite `EmbeddingsDataModule` -Responsibilities, in order: - -1. **`prepare_data()`** (single-process): - - Download embeddings artifact: `mlflow.artifacts.download_artifacts(run_id=embeddings_run_id, artifact_path="train")`. Cache path on `self`. - - Download kfold artifact: `mlflow.artifacts.download_artifacts(run_id=kfold_run_id, artifact_path="/kfold_tiles.parquet")`. -2. **`setup(stage)`**: - - Read kfold parquet into pandas (small: ~few M rows × handful of cols). **Keep** `roi_coverage_*` columns until thresholds are applied. - - Build a `raw → canonical` lookup from the config's `class_mapping` (dict of canonical → list[raw]) and apply it to the `label` column. - - Drop rows whose raw label isn't in the mapping (handles `"background"` and any stragglers) — gated by `drop_unmapped: true`. - - Compute per-tile **collapsed coverage**: for each canonical class C, sum `roi_coverage_` over its raw members. Add a `dominant_coverage` column = the collapsed coverage of the assigned canonical label. - - Apply `tissue_prop_min` and `class_coverage_min` row masks. 
Log row-count deltas at each step (initial → after raw-label drop → after `tissue_prop_min` → after `class_coverage_min`) as MLflow metrics so threshold sweeps are interpretable. - - Drop `roi_coverage_*` columns once thresholds are done. - - Load embeddings as an Arrow table: `pyarrow.dataset.dataset(emb_dir, format="parquet").to_table(columns=["slide_id","x","y","embedding"])`. Memory-mapped, zero-copy. - - **Join** on `(slide_id, x, y)` via `pyarrow.Table.join(labels_table, keys=["slide_id","x","y"], join_type="inner")`. The two parquets share this key by construction (both downstream of `filter_tiles/train_tiles.parquet`, neither remaps coords), so the inner-join is effectively 1:1. Use `pyarrow.Table.join` rather than `pandas.merge` — the embedding column is heavy (~2560 × 4B × N) and we want to avoid copies. Wrap the joined Arrow table back into `datasets.Dataset(arrow_table=...)`. - - **Verify the join**: log `n_embeddings`, `n_labels`, `n_joined` as MLflow metrics. If `n_joined < n_labels`, log a warning with the gap — it means the embed run dropped tiles (e.g., upstream API failures past retries) and you'll want to know. - - Map `label` → `y` (int) using `class_indices` (or `class_names.index(label)`). - - For the configured `val_fold`, split into `train_set` (`fold != val_fold`) and `val_set` (`fold == val_fold`). `with_format("torch", columns=["embedding", "y"])`. -3. **`train_dataloader` / `val_dataloader`**: as today. Drop `test_dataloader` / `test_dir` arg in this PR (or leave the arg optional `test_dir: Optional[str] = None` with a `NotImplementedError` if requested — cleaner to just remove until the test PR adds it). -4. Make `val_fold` a settable attribute (not just hparam) so the train script can rebuild the data split per fold without reloading the parquet: - - Cache the joined `full_dataset` on the datamodule. 
- - Expose a `set_val_fold(fold: int)` that re-derives `train_set` / `val_set` by filtering — this avoids re-downloading and re-joining N times. - -### Step 3 — K-fold loop in `ml/train.py` -Refactor `main()` for `mode == "fit"`: -```python -datamodule = instantiate(config.data) -datamodule.prepare_data() -datamodule.setup("fit") # builds full_dataset once - -per_fold_metrics: list[dict] = [] -for fold in range(config.n_folds): - pl.seed_everything(config.seed + fold) # fresh init per fold - datamodule.set_val_fold(fold) - model = instantiate(config.model) - trainer = instantiate(config.trainer, logger=logger) - trainer.fit(model, datamodule=datamodule) - # collect last-epoch val metrics - per_fold_metrics.append({k: float(v) for k, v in trainer.callback_metrics.items() - if k.startswith("val/")}) - # log per-fold - for k, v in per_fold_metrics[-1].items(): - mlflow.log_metric(f"fold_{fold}/{k}", v) - -# aggregate -import numpy as np -keys = per_fold_metrics[0].keys() -for k in keys: - vals = np.array([m[k] for m in per_fold_metrics]) - mlflow.log_metric(f"{k}_mean", vals.mean()) - mlflow.log_metric(f"{k}_std", vals.std()) -``` -Note: `trainer.test()` removed. Keep `mode == "test"` as `raise NotImplementedError("Test mode arrives in the test-set PR")` for now — clearer than silently breaking. - -### Step 4 — Update `configs/ml/linear_probe.yaml` -- Replace `embeddings_run_id` with two run-id fields: - ```yaml - embeddings_run_id: ??? # e.g. f05076dcd5e64cb2839efe5fb20a22ae - kfold_run_id: ??? # e.g. 2e81b0597b614ba8b675e3b34528c1df - ``` -- Add `n_folds: 5` (or wire to read from kfold artifact metadata if available). -- Drop `test_dir` from `data:` block. 
-- Update `data:` to: - ```yaml - data: - _target_: ml.data.embeddings_datamodule.EmbeddingsDataModule - embeddings_run_id: ${embeddings_run_id} - kfold_run_id: ${kfold_run_id} - kfold_artifact_path: kfold_split/kfold_tiles.parquet # confirm against logged path - class_mapping: ${class_mapping.class_mapping} - class_indices: ${class_indices} - drop_unmapped: true - tissue_prop_min: 0.0 # threshold sweep knob - class_coverage_min: 0.0 # threshold sweep knob - batch_size: 1024 - num_workers: 4 - ``` -- The `defaults` block should select `collapse_alterations_to_other` for `class_mapping`. -- Remove the `val_fold` field at the global level — folds are now driven by the loop. - -### Step 5 — Strengthen `LinearProbe` -Small additions, low risk: -- Per-class F1 logged at `validation_epoch_end`: `F1Score(..., average=None)` → log each class as `val/f1_`. -- Confusion matrix (`torchmetrics.ConfusionMatrix`) computed at end of validation, logged as an MLflow figure or table per fold. -- Optional embedding L2-normalization toggle (`normalize_input: bool`) — Virchow2 outputs are typically not L2-normalized at the CLS-token stage; making this a flag is one line and a common probe variant to try. -- Optional class weights for CE (defer wiring, but leave a `class_weights: Optional[list[float]] = None` parameter). - -### Step 6 — Logging hygiene -- Log artifacts: the resolved class list, the join coverage stats (`#tiles in embeddings`, `#tiles in kfold`, `#joined`, `#dropped_background`, `#dropped_no_label`). -- Log a one-row summary table per fold: `n_train`, `n_val`, label distribution. -- Set `metadata.run_name` to include both run-ids: `"Linear probe (embed=${embeddings_run_id[:8]}, kfold=${kfold_run_id[:8]})"`. 
- -### Step 7 — Smoke run -End-to-end against the existing artifacts: -``` -embeddings_run_id=f05076dcd5e64cb2839efe5fb20a22ae -kfold_run_id=2e81b0597b614ba8b675e3b34528c1df -embed_dim= -n_folds=5 # confirm against the kfold run's params -``` -Run for 1–2 epochs first to confirm wiring, then full `max_epochs=30`. - ---- - -## 3. Resolved decisions - -1. **Virchow2 embedding dimension** — 2560. -2. **Kfold artifact path** — `kfold_split/kfold_tiles.parquet`. -3. **n_folds** — 5. -4. **Validation cadence** — fit one fold, then move on (sequential). -5. **Reproducibility of `set_val_fold`** — confirmed: re-instantiate model + seed-per-fold (`seed + fold`). - ---- - -## 4. Out of scope for this PR (next PR) - -- Test-set evaluation (single pass, no folds, possibly with slide-level aggregation). -- Fine-tuning beyond the linear head. -- Class-weighted CE / focal loss / soft-label CE on `roi_coverage` proportions. -- Model selection across folds (best ckpt per fold, ensemble at test time). -- Multi-GPU / DDP — single GPU is plenty for a linear probe on cached embeddings. diff --git a/ml/data/embeddings_datamodule.py b/ml/data/embeddings_datamodule.py deleted file mode 100644 index fc2fe74e..00000000 --- a/ml/data/embeddings_datamodule.py +++ /dev/null @@ -1,308 +0,0 @@ -import logging -import warnings -from pathlib import Path -from typing import Any - -import lightning as pl -import mlflow -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.dataset as pad -from datasets import Dataset -from omegaconf import OmegaConf -from torch.utils.data import DataLoader - - -log = logging.getLogger(__name__) - - -class EmbeddingsDataModule(pl.LightningDataModule): - """Linear-probe data module. - - Downloads embeddings and kfold-split artifacts from MLflow, joins them on - (slide_id, x, y), applies class mapping and coverage filters, and exposes - train/val splits per fold via set_val_fold(). 
- """ - - def __init__( - self, - embeddings_run_id: str, - kfold_run_id: str, - class_mapping: dict[str, list[str]], - class_indices: dict[str, int], - kfold_artifact_path: str = "kfold_split/kfold_tiles.parquet", - drop_unmapped: bool = True, - tissue_prop_min: float = 0.0, - class_coverage_min: float = 0.0, - batch_size: int = 1024, - num_workers: int = 4, - ) -> None: - super().__init__() - self.save_hyperparameters() - self.embeddings_run_id = embeddings_run_id - self.kfold_run_id = kfold_run_id - self.kfold_artifact_path = kfold_artifact_path - self.drop_unmapped = drop_unmapped - self.tissue_prop_min = tissue_prop_min - self.class_coverage_min = class_coverage_min - self.batch_size = batch_size - self.num_workers = num_workers - cm: Any = ( - OmegaConf.to_container(class_mapping, resolve=True) - if OmegaConf.is_config(class_mapping) - else class_mapping - ) - ci: Any = ( - OmegaConf.to_container(class_indices, resolve=True) - if OmegaConf.is_config(class_indices) - else class_indices - ) - self._class_mapping: dict[str, list[str]] = dict(cm) - self._class_indices: dict[str, int] = dict(ci) - self._raw_to_canonical: dict[str, str] = { - raw: canonical - for canonical, raws in self._class_mapping.items() - for raw in raws - } - # Accept already-canonical labels as identity. The filtered tiles parquet - # collapses ROI columns at tiling time, so kfold writes canonical names - # (e.g. "Epithelium") directly into `label`; the raw→canonical lists in - # the class-mapping YAML still cover legacy un-collapsed parquets. 
- self._raw_to_canonical.update({c: c for c in self._class_indices}) - self._emb_dir: Path | None = None - self._kfold_path: Path | None = None - self.full_dataset: Dataset | None = None - self.train_set: Dataset | None = None - self.val_set: Dataset | None = None - self._val_fold: int = 0 - self._fold_array: np.ndarray | None = None - - # ------------------------------------------------------------------ - # prepare_data - single-process MLflow download - # ------------------------------------------------------------------ - - def prepare_data(self) -> None: - emb_local = mlflow.artifacts.download_artifacts( - run_id=self.embeddings_run_id, - artifact_path="train", - ) - self._emb_dir = Path(emb_local) - - kfold_local = mlflow.artifacts.download_artifacts( - run_id=self.kfold_run_id, - artifact_path=self.kfold_artifact_path, - ) - self._kfold_path = Path(kfold_local) - - # ------------------------------------------------------------------ - # setup - join, filter, build full_dataset once - # ------------------------------------------------------------------ - - def setup(self, stage: str) -> None: - if self.full_dataset is not None: - return # already built; use set_val_fold() to change splits - - assert self._emb_dir is not None and self._kfold_path is not None, ( - "Call prepare_data() before setup()" - ) - - # --- load kfold labels (small) --- - labels_df = pd.read_parquet(self._kfold_path) - n_initial = len(labels_df) - roi_cols = [c for c in labels_df.columns if c.startswith("roi_coverage_")] - - # --- apply raw→canonical label mapping --- - labels_df["label"] = labels_df["label"].map(self._raw_to_canonical) - if self.drop_unmapped: - n_before = len(labels_df) - labels_df = labels_df[labels_df["label"].notna()].copy() - dropped = n_before - len(labels_df) - if dropped: - log.warning("Dropping %d tiles with unmapped labels", dropped) - mlflow.log_metric("n_tiles_initial", n_initial) - mlflow.log_metric("n_tiles_after_label_map", len(labels_df)) - - # --- 
tissue_prop filter --- - if self.tissue_prop_min > 0.0: - labels_df = labels_df[ - labels_df["tissue_prop"] >= self.tissue_prop_min - ].copy() - mlflow.log_metric("n_tiles_after_tissue_prop_filter", len(labels_df)) - - # --- compute dominant_coverage (coverage of the assigned canonical class) --- - dom_cov = np.zeros(len(labels_df), dtype=np.float32) - label_arr = labels_df["label"].to_numpy() - for canonical, raws in self._class_mapping.items(): - cols = [ - f"roi_coverage_{r}" - for r in raws - if f"roi_coverage_{r}" in labels_df.columns - ] - mask: np.ndarray = label_arr == canonical - if cols and mask.any(): - dom_cov[mask] = labels_df.loc[mask, cols].sum(axis=1).to_numpy() - labels_df["dominant_coverage"] = dom_cov - - if self.class_coverage_min > 0.0: - labels_df = labels_df[ - labels_df["dominant_coverage"] >= self.class_coverage_min - ].copy() - mlflow.log_metric("n_tiles_after_class_coverage_filter", len(labels_df)) - - labels_df = labels_df.drop( - columns=[*roi_cols, "dominant_coverage"], errors="ignore" - ) - - # --- map label → integer class index (column name avoids collision with tile y-coord) --- - labels_df["target"] = labels_df["label"].map(self._class_indices).astype(int) - - # --- load embeddings (memory-mapped Arrow) --- - emb_files = sorted((self._emb_dir / "tiles").rglob("*.parquet")) - emb_table = pad.dataset([str(f) for f in emb_files], format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) - log.info("Embeddings schema: %s", emb_table.schema) - mlflow.log_metric("n_embeddings", len(emb_table)) - - # --- inner-join on (slide_id, x, y) --- - labels_table = pa.Table.from_pandas( - labels_df[["slide_id", "x", "y", "fold", "tissue_prop", "target"]], - preserve_index=False, - ) - log.info("Labels schema: %s", labels_table.schema) - mlflow.log_metric("n_labels", len(labels_table)) - - # Acero join requires concrete, matching types on join keys. Normalise both - # tables: slide_id → large_string, x/y → int64. 
Handles null-type columns - # that can arise when Ray writes parquets with an inferred null schema, and - # type mismatches between string vs large_string across files. - # Acero also does not support list-typed non-key fields, so the embedding - # column is excluded from the join and reattached via row-index lookup. - join_key_types: dict[str, pa.DataType] = { - "slide_id": pa.large_string(), - "x": pa.int64(), - "y": pa.int64(), - } - - emb_keys = emb_table.select(["slide_id", "x", "y"]).append_column( - "_emb_row", pa.array(range(len(emb_table)), type=pa.int64()) - ) - for tbl_name, tbl in [("emb_keys", emb_keys), ("labels", labels_table)]: - for col_name, target_type in join_key_types.items(): - idx = tbl.schema.get_field_index(col_name) - actual_type = tbl.schema.field(col_name).type - if actual_type != target_type: - log.warning( - "%s.%s has type %s, casting to %s", - tbl_name, col_name, actual_type, target_type, - ) - if tbl_name == "emb_keys": - emb_keys = emb_keys.set_column( - idx, col_name, emb_keys.column(col_name).cast(target_type) - ) - else: - labels_table = labels_table.set_column( - idx, col_name, labels_table.column(col_name).cast(target_type) - ) - - joined_meta = emb_keys.join( - labels_table, - keys=["slide_id", "x", "y"], - join_type="inner", - ) - emb_indices = joined_meta.column("_emb_row") - - emb_type = emb_table.schema.field("embedding").type - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_type): - large_list_type = pa.large_list(emb_type.value_type) - emb_col = emb_col.cast(large_list_type) - emb_type = large_list_type - - joined = joined_meta.drop_columns(["_emb_row"]).append_column( - pa.field("embedding", emb_type), - emb_col.take(emb_indices), - ) - n_joined = len(joined) - mlflow.log_metric("n_joined", n_joined) - - gap = len(labels_table) - n_joined - if gap > 0: - warnings.warn( - f"Join gap: {gap} label tiles have no matching embedding " - "(embed run may have dropped tiles due to upstream failures).", - 
stacklevel=2, - ) - - self.full_dataset = Dataset(arrow_table=joined) - self._fold_array = np.asarray(joined.column("fold"), dtype=np.int64) - self.set_val_fold(self._val_fold) - - # ------------------------------------------------------------------ - # fold management - # ------------------------------------------------------------------ - - def set_val_fold(self, fold: int) -> None: - self._val_fold = fold - if self.full_dataset is None or self._fold_array is None: - return - train_idx = np.flatnonzero(self._fold_array != fold).tolist() - val_idx = np.flatnonzero(self._fold_array == fold).tolist() - self.train_set = self.full_dataset.select(train_idx).with_format( - "torch", columns=["embedding", "target"] - ) - self.val_set = self.full_dataset.select(val_idx).with_format( - "torch", columns=["embedding", "target"] - ) - - def compute_class_weights(self, method: str = "balanced") -> list[float]: - """Compute per-class loss weights from the current training fold. - - ``balanced`` follows sklearn's ``compute_class_weight``: - ``w_c = n_samples / (n_classes * count_c)``. ``inverse`` uses - ``1 / count_c`` normalised to mean 1. 
- """ - if self.full_dataset is None or self._fold_array is None: - raise RuntimeError("call setup()/set_val_fold() before compute_class_weights()") - targets = np.asarray(self.full_dataset.data.column("target"), dtype=np.int64) - train_mask = self._fold_array != self._val_fold - train_targets = targets[train_mask] - n_classes = len(self._class_indices) - counts = np.bincount(train_targets, minlength=n_classes).astype(np.float64) - counts = np.maximum(counts, 1.0) - if method == "balanced": - weights = train_targets.size / (n_classes * counts) - elif method == "inverse": - weights = 1.0 / counts - weights = weights / weights.mean() - else: - raise ValueError(f"Unknown class-weight method: {method!r}") - return weights.tolist() - - # ------------------------------------------------------------------ - # dataloaders - # ------------------------------------------------------------------ - - @staticmethod - def _collate(batch: list[dict[str, Any]]) -> tuple[Any, Any]: - import torch - - x = torch.stack([b["embedding"].float() for b in batch]) - y = torch.stack([b["target"].long() for b in batch]) - return x, y - - def _loader(self, ds: Dataset, shuffle: bool) -> DataLoader[Any]: - return DataLoader( - ds, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - collate_fn=self._collate, - ) - - def train_dataloader(self) -> DataLoader[Any]: - return self._loader(self.train_set, shuffle=True) - - def val_dataloader(self) -> DataLoader[Any]: - return self._loader(self.val_set, shuffle=False) diff --git a/ml/models/__init__.py b/ml/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ml/models/linear_probe.py b/ml/models/linear_probe.py deleted file mode 100644 index 710d6ab5..00000000 --- a/ml/models/linear_probe.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import cast - -import lightning as pl -import mlflow -import torch -import torch.nn.functional as F -from torch import nn, optim -from torchmetrics import 
Accuracy, ConfusionMatrix, F1Score, MetricCollection - - -class LinearProbe(pl.LightningModule): - def __init__( - self, - embed_dim: int, - num_classes: int, - class_names: list[str] | None = None, - lr: float = 1e-3, - weight_decay: float = 0.0, - normalize_input: bool = False, - class_weights: list[float] | None = None, - ) -> None: - super().__init__() - self.save_hyperparameters() - self.embed_dim = embed_dim - self.num_classes = num_classes - self.class_names = class_names - self.lr = lr - self.weight_decay = weight_decay - self.normalize_input = normalize_input - self.head = nn.Linear(embed_dim, num_classes) - - base_metrics = MetricCollection( - { - "acc": Accuracy(task="multiclass", num_classes=num_classes), - "f1_macro": F1Score( - task="multiclass", num_classes=num_classes, average="macro" - ), - } - ) - self.train_metrics = base_metrics.clone(prefix="train/") - self.val_metrics = base_metrics.clone(prefix="val/") - - self.val_f1_per_class = F1Score( - task="multiclass", num_classes=num_classes, average=None - ) - self.val_conf_matrix = ConfusionMatrix( - task="multiclass", num_classes=num_classes - ) - - if class_weights is not None: - self.register_buffer( - "class_weights", torch.tensor(class_weights, dtype=torch.float) - ) - else: - self.class_weights = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.normalize_input: - x = F.normalize(x, dim=-1) - return self.head(x) - - def training_step( - self, batch: tuple[torch.Tensor, torch.Tensor], _: int - ) -> torch.Tensor: - x, y = batch - logits = self(x) - loss = F.cross_entropy(logits, y, weight=self.class_weights) - self.log("train/loss", loss, prog_bar=True) - self.log_dict(self.train_metrics(logits, y), prog_bar=True) - return loss - - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], _: int - ) -> torch.Tensor: - x, y = batch - logits = self(x) - loss = F.cross_entropy(logits, y, weight=self.class_weights) - self.log("val/loss", loss, prog_bar=True) - 
self.log_dict(self.val_metrics(logits, y), prog_bar=True) - self.val_f1_per_class.update(logits, y) - self.val_conf_matrix.update(logits, y) - return loss - - def on_validation_epoch_end(self) -> None: - f1_per_class = cast("torch.Tensor", self.val_f1_per_class.compute()) - class_names = self.class_names or [str(i) for i in range(self.num_classes)] - for name, f1 in zip(class_names, f1_per_class, strict=True): - self.log(f"val/f1_{name}", f1) - - conf_mat = cast("torch.Tensor", self.val_conf_matrix.compute()).cpu().numpy() - try: - import matplotlib.pyplot as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(8, 7)) - im = ax.imshow(conf_mat, interpolation="nearest") - ax.set( - xticks=np.arange(len(class_names)), - yticks=np.arange(len(class_names)), - xticklabels=class_names, - yticklabels=class_names, - xlabel="Predicted", - ylabel="True", - ) - plt.setp(ax.get_xticklabels(), rotation=45, ha="right") - fig.colorbar(im, ax=ax) - fig.tight_layout() - mlflow.log_figure(fig, f"confusion_matrix_epoch{self.current_epoch}.png") - plt.close(fig) - except Exception: - pass - - self.val_f1_per_class.reset() - self.val_conf_matrix.reset() - - def configure_optimizers(self) -> optim.Optimizer: - return optim.AdamW( - self.parameters(), - lr=self.lr, - weight_decay=self.weight_decay, - ) diff --git a/ml/train.py b/ml/train.py deleted file mode 100644 index 0fbfdd03..00000000 --- a/ml/train.py +++ /dev/null @@ -1,76 +0,0 @@ -import secrets -from typing import TYPE_CHECKING, Any - -import hydra -import lightning as pl -import mlflow -import numpy as np -from hydra.utils import instantiate -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog, with_cli_args -from rationai.mlkit.lightning.loggers import MLFlowLogger - - -if TYPE_CHECKING: - from ml.data.embeddings_datamodule import EmbeddingsDataModule - - -if not OmegaConf.has_resolver("random_seed"): - OmegaConf.register_new_resolver( - "random_seed", lambda: secrets.randbits(32), 
use_cache=True - ) -if not OmegaConf.has_resolver("len"): - OmegaConf.register_new_resolver("len", len) - - -@with_cli_args(["+ml=linear_probe"]) -@hydra.main(config_path="../configs", config_name="ml", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - if config.mode == "fit": - _fit(config, logger) - elif config.mode == "test": - raise NotImplementedError("Test mode arrives in the test-set PR") - else: - raise ValueError(f"Unknown mode: {config.mode}") - - -def _fit(config: DictConfig, logger: MLFlowLogger) -> None: - datamodule: EmbeddingsDataModule = instantiate(config.data) - datamodule.prepare_data() - datamodule.setup("fit") - - per_fold_metrics: list[dict[str, Any]] = [] - for fold in range(config.n_folds): - pl.seed_everything(config.seed + fold) - datamodule.set_val_fold(fold) - - weights_spec = config.model.get("class_weights", None) - weights = ( - datamodule.compute_class_weights(weights_spec) - if isinstance(weights_spec, str) - else weights_spec - ) - model: pl.LightningModule = instantiate(config.model, class_weights=weights) - trainer: pl.Trainer = instantiate(config.trainer, logger=logger) - trainer.fit(model, datamodule=datamodule) - - fold_metrics = { - k: float(v) - for k, v in trainer.callback_metrics.items() - if k.startswith("val/") - } - per_fold_metrics.append(fold_metrics) - for k, v in fold_metrics.items(): - mlflow.log_metric(f"fold_{fold}/{k}", v) - - if per_fold_metrics: - keys = per_fold_metrics[0].keys() - for k in keys: - vals = np.array([m[k] for m in per_fold_metrics]) - mlflow.log_metric(f"{k}_mean", float(vals.mean())) - mlflow.log_metric(f"{k}_std", float(vals.std())) - - -if __name__ == "__main__": - main() diff --git a/scripts/submit_linear_probe.py b/scripts/submit_linear_probe.py deleted file mode 100644 index 61e5c847..00000000 --- a/scripts/submit_linear_probe.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - 
job_name="tissue-classification-linear-probe", - username="vcifka", - cpu=4, - memory="32Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/linear-probe https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml.train +experiment=ml/linear_probe_collapse_alterations_to_other", - ], - storage=[storage.secure.PROJECTS], -) From 14909e24d28ee3f8ccac0a740d45a4bf794d8178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 00:29:53 +0200 Subject: [PATCH 074/107] feat: add functionality to submit final train for both adamw and lbfgs --- .../ml/linear_classifier_final_adamw.yaml | 18 +++ .../ml/linear_classifier_final_lbfgs.yaml | 36 +++++ configs/ml/data/embedding_final.yaml | 2 + docs/stratified_group_kfold.md | 141 ++++++++++++++++++ ml/data/data_module.py | 4 +- scripts/submit_test_linear_final.py | 10 +- scripts/submit_train_linear_final_adamw.py | 18 +++ scripts/submit_train_linear_final_lbfgs.py | 18 +++ 8 files changed, 239 insertions(+), 8 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_final_adamw.yaml create mode 100644 configs/experiment/ml/linear_classifier_final_lbfgs.yaml create mode 100644 docs/stratified_group_kfold.md create mode 100644 scripts/submit_train_linear_final_adamw.py create mode 100644 scripts/submit_train_linear_final_lbfgs.py diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/linear_classifier_final_adamw.yaml new file mode 100644 index 00000000..76e771d3 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_final_adamw.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final + - override /ml/trainer: default + - _self_ + +# AdamW final: trained to convergence with the same early-stopping rule as the +# k-fold sweep (monitor train/loss_epoch, patience 1, min_delta 1e-4), not a +# fixed 
6-epoch budget. weight_decay=1e-3 = best AdamW sweep point (flat curve). +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 1.0e-3 + +metadata: + run_name: Final Linear Classifier AdamW ${dataset.name} + description: "Final AdamW linear probe over frozen Virchow2 embeddings, trained on all training folds with early stopping on train/loss_epoch." diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml new file mode 100644 index 00000000..df42c9c4 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final + - _self_ + +# LBFGS final: exact solve of the convex objective on the full training batch. +# weight_decay=1e-2 = best LBFGS sweep point. Full-batch guard requires +# train_shuffle=false, train_drop_last=false, batch_size >= len(train); +# num_workers=0 avoids the single-batch IPC deadlock. +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + num_workers: 0 + +model: + optimizer: lbfgs + learning_rate: 1.0 + weight_decay: 1.0e-2 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false + +metadata: + run_name: Final Linear Classifier LBFGS ${dataset.name} + description: "Final LBFGS linear probe over frozen Virchow2 embeddings, exact full-batch solve of the convex objective on all training folds." 
diff --git a/configs/ml/data/embedding_final.yaml b/configs/ml/data/embedding_final.yaml index a1f6ca80..56355946 100644 --- a/configs/ml/data/embedding_final.yaml +++ b/configs/ml/data/embedding_final.yaml @@ -3,6 +3,8 @@ data: batch_size: 1024 num_workers: 4 + train_shuffle: true + train_drop_last: false train: _target_: ml.data.datasets.EmbeddingTilesDataset diff --git a/docs/stratified_group_kfold.md b/docs/stratified_group_kfold.md new file mode 100644 index 00000000..c5e2ffa1 --- /dev/null +++ b/docs/stratified_group_kfold.md @@ -0,0 +1,141 @@ +# Stratified Group K-Fold Split + +## Motivation + +The original tile-level `StratifiedKFold` split balanced tissue labels well, but it allowed tiles from the same slide to appear in both training and validation partitions. Because tiles from a single whole-slide image are not independent, this can leak slide-specific visual patterns into validation and make validation performance overly optimistic. + +To reduce this leakage risk, the splitter now supports `kfold_strategy: stratified_group`, implemented in `split/kfold_split.py`. This mode uses `StratifiedGroupKFold`, stratifying by tissue label while treating `slide_id` as the grouping variable. As a result, all tiles from the same slide are assigned to exactly one validation fold. + +The original tile-level `StratifiedKFold` strategy is still available as `kfold_strategy: stratified`. It can be run when slide-level separation is not required, for example for debugging, comparison against older experiments, or workflows where tile-level stratification is intentionally preferred. + +## Stratification Target and Rare-Class Protocol + +Labels are derived per tile from the `roi_coverage_*` columns: + +- `label` is the tissue class with the highest ROI coverage. +- `background` is assigned when a tile has zero ROI coverage. +- `tissue_prop` is the sum of all `roi_coverage_*` values for the tile. 
+ +For grouped splitting, the important constraint is no longer only the number of tiles per class. `StratifiedGroupKFold` also needs each stratification class to be represented across enough distinct groups. In this project, a group is a slide, so each retained class must appear in at least `n_folds` distinct `slide_id` values. + +The rare-class protocol for `stratified_group` is therefore slide-based: + +- The splitter counts the number of distinct slides containing each label. +- Any label present in fewer than `n_folds` slides is considered rare for grouped splitting. +- All tiles with rare labels are dropped before fold assignment. +- A warning lists each dropped label and the number of slides in which it appears. +- If the rare-class filtering would drop every tile, the script raises a `ValueError`. + +This differs from the older `stratified` strategy. The tile-level strategy collapses rare tile-count classes into `background` only for stratification. The grouped strategy does not collapse rare classes into `background`, because background can be sparse or filtered upstream and because collapsing would not reliably solve the slide-level group constraint. + +## Fold Assignment + +`StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=...)` is fitted with: + +- `y`: the derived tile label. +- `groups`: the tile `slide_id`. + +Each tile is assigned to its validation fold, `fold in [0, n_folds)`. For any fold `k`, the validation set is `fold == k` and the training set is the complement, `fold != k`. + +The group constraint means that each slide appears in only one validation fold. Consequently, no tile from a validation slide appears in the corresponding training split. + +## Output + +The script writes one parquet artifact, `kfold_tiles.parquet`, under the configured `mlflow_artifact_path`. 
+ +The output keeps the filtered input tile dataset and adds fold metadata: + +| Column | Type | Source | +| --- | --- | --- | +| `tissue_prop` | float | Sum of `roi_coverage_*` columns. | +| `fold` | int8 | Validation fold index in `[0, n_folds)`. | + +For `stratified_group`, rare labels may be removed before writing the parquet. When this happens, the logged metric `dropped_rare_class_tiles` records how many tiles were excluded. + +Note: labels are derived inside the splitter for stratification and statistics. The current implementation does not add a new `label` column to the output parquet unless such a column is already present in the input dataset. + +## Logged Statistics + +Per-fold metrics are emitted to MLflow: + +- `fold__train_tiles`: number of tiles outside validation fold `k`. +- `fold__val_tiles`: number of validation tiles in fold `k`. +- `fold__val_tile_pct`: fraction of retained tiles assigned to validation fold `k`. +- `fold__val_slides`: number of distinct validation slides in fold `k`. +- `fold__val_tissue_prop_mean`: mean tissue coverage in validation fold `k`. +- `fold__val_tissue_prop_std`: tissue coverage standard deviation in validation fold `k`. +- `fold_size_cv`: coefficient of variation of validation fold sizes. +- `dropped_rare_class_tiles`: number of dropped rare-class tiles, logged only when rare-class filtering removes tiles. + +The script also logs a label-distribution table as an MLflow artifact: + +- `fold_statistics/label_distribution.json`: fold by original derived label counts. + +Unlike the tile-level `stratified` strategy, the `stratified_group` strategy does not log `fold_statistics/stratification_label_distribution.json`, because it uses the original derived labels directly and does not create a separate collapsed stratification-label array. + +## Split Statistics + +Detailed JSON representations of the metrics are available within the respective MLflow run artifacts. 
+ +### Global Metrics + +- Total retained tiles: 1,102,086 +- Original tiles before rare-class filtering: 1,102,086 +- Dropped rare-class tiles: 0 +- n_folds: 5 +- Random state: 42 +- K-fold strategy: `stratified_group` +- Rare labels dropped before splitting: none +- Fold size CV: 0.0517 + +### Per-Fold Metrics + +| Fold | Train tiles | Val tiles | Val % | Val slides | tissue_prop mean +- std | +| --- | --- | --- | --- | --- | --- | +| 0 | 859,277 | 242,809 | 22.03% | 26 | 0.9077 +- 0.2667 | +| 1 | 890,838 | 211,248 | 19.17% | 26 | 0.8809 +- 0.2981 | +| 2 | 884,420 | 217,666 | 19.75% | 27 | 0.8979 +- 0.2797 | +| 3 | 886,197 | 215,889 | 19.59% | 27 | 0.8705 +- 0.3084 | +| 4 | 887,612 | 214,474 | 19.46% | 31 | 0.8812 +- 0.2978 | + +### Original Label Distribution per Fold + +For `stratified_group`, this table reflects the labels that were retained after rare-class filtering. + +| Label | Fold 0 | Fold 1 | Fold 2 | Fold 3 | Fold 4 | +| --- | --- | --- | --- | --- | --- | +| background | 4.4941% | 6.0493% | 5.2140% | 6.5742% | 6.1728% | +| Blood | 0.4156% | 0.5027% | 0.5035% | 0.4530% | 0.9059% | +| Connective-Tissue | 3.1514% | 3.4386% | 2.9573% | 3.3703% | 3.8252% | +| Epithelium | 1.2413% | 1.4268% | 1.3948% | 1.3887% | 1.4118% | +| Fat | 10.8517% | 10.8948% | 13.8033% | 12.2832% | 10.8428% | +| Muscle | 14.7194% | 2.9955% | 3.2605% | 2.4929% | 3.3776% | +| Nerve | 1.6153% | 1.8968% | 1.8556% | 1.7856% | 1.8725% | +| Other | 63.5112% | 72.7955% | 71.0111% | 71.6521% | 71.5914% | + +### Slide Distribution per Fold + +Use this table to document how slides are distributed across validation folds. This is the key leakage-control diagnostic for the grouped split. 
+ +| Fold | Val slides | Val slide % | Val tiles | Val tile % | +| --- | --- | --- | --- | --- | +| 0 | 26 | 18.98% | 242,809 | 22.03% | +| 1 | 26 | 18.98% | 211,248 | 19.17% | +| 2 | 27 | 19.71% | 217,666 | 19.75% | +| 3 | 27 | 19.71% | 215,889 | 19.59% | +| 4 | 31 | 22.63% | 214,474 | 19.46% | + +### Optional: Label Counts per Fold + +Use this table if you want to report absolute counts in addition to percentages. + +| Label | Fold 0 | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Total | +| --- | --- | --- | --- | --- | --- | --- | +| background | 10,912 | 12,779 | 11,349 | 14,193 | 13,239 | 62,472 | +| Blood | 1,009 | 1,062 | 1,096 | 978 | 1,943 | 6,088 | +| Connective-Tissue | 7,652 | 7,264 | 6,437 | 7,276 | 8,204 | 36,833 | +| Epithelium | 3,014 | 3,014 | 3,036 | 2,998 | 3,028 | 15,090 | +| Fat | 26,349 | 23,015 | 30,045 | 26,518 | 23,255 | 129,182 | +| Muscle | 35,740 | 6,328 | 7,097 | 5,382 | 7,244 | 61,791 | +| Nerve | 3,922 | 4,007 | 4,039 | 3,855 | 4,016 | 19,839 | +| Other | 154,211 | 153,779 | 154,567 | 154,689 | 153,545 | 770,791 | diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 399ae1a9..0f9e527d 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -59,9 +59,9 @@ def train_dataloader(self) -> Iterable[Input]: persistent_workers=self.num_workers > 0, ) - def val_dataloader(self) -> Iterable[Input] | None: + def val_dataloader(self) -> Iterable[Input]: if self.val is None: - return None + return [] return DataLoader( self.val, batch_size=self.batch_size, diff --git a/scripts/submit_test_linear_final.py b/scripts/submit_test_linear_final.py index 8176035a..c9ad162b 100644 --- a/scripts/submit_test_linear_final.py +++ b/scripts/submit_test_linear_final.py @@ -1,23 +1,21 @@ from kube_jobs import storage, submit_job -checkpoint = ( - "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" -) +checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" submit_job( 
job_name="tissue-classification-test-linear-final", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - f'uv run python -m ml +experiment=ml/linear_classifier_final mode=test checkpoint=\\"{checkpoint}\\"', + f'uv run python -m ml +experiment=... mode=test checkpoint=\\"{checkpoint}\\"', ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_linear_final_adamw.py b/scripts/submit_train_linear_final_adamw.py new file mode 100644 index 00000000..7cf5ee41 --- /dev/null +++ b/scripts/submit_train_linear_final_adamw.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear-final-adamw", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=ml/linear_classifier_final_adamw", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/scripts/submit_train_linear_final_lbfgs.py b/scripts/submit_train_linear_final_lbfgs.py new file mode 100644 index 00000000..871f5d1a --- /dev/null +++ b/scripts/submit_train_linear_final_lbfgs.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear-final-lbfgs", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=ml/linear_classifier_final_lbfgs", + ], + storage=[storage.secure.PROJECTS], +) From 816736340508d611b0a2e11b0f476c9077ed25e3 Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 10:12:20 +0200 Subject: [PATCH 075/107] feat: implement prediction maps --- .../ml/linear_classifier_test_adamw.yaml | 22 ++ .../ml/linear_classifier_test_lbfgs.yaml | 18 + configs/ml/trainer/final.yaml | 8 + ml/callbacks/__init__.py | 3 +- ml/callbacks/tiff_prediction_map_writer.py | 324 ++++++++++++++++++ ml/meta_arch.py | 113 +----- scripts/submit_test_linear.py | 31 -- scripts/submit_train_linear_final.py | 6 +- scripts/submit_train_linear_final_adamw.py | 18 - scripts/submit_train_linear_final_lbfgs.py | 18 - scripts/submit_train_linear_probe.py | 6 +- 11 files changed, 388 insertions(+), 179 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_test_adamw.yaml create mode 100644 configs/experiment/ml/linear_classifier_test_lbfgs.yaml create mode 100644 ml/callbacks/tiff_prediction_map_writer.py delete mode 100644 scripts/submit_test_linear.py delete mode 100644 scripts/submit_train_linear_final_adamw.py delete mode 100644 scripts/submit_train_linear_final_lbfgs.py diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml new file mode 100644 index 00000000..ff2431f2 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final_adamw + - _self_ + +# Test the AdamW final checkpoint on the held-out test split. Same model +# architecture as the final run (required for state_dict load); optimizer +# fields are inert at test. Checkpoint is passed on the CLI. +# +# The AdamW final inherits trainer/default (early stopping for fit), which has +# no TIFF map writer. Add it here so the AdamW test produces the same +# WSI-aligned prediction maps as the LBFGS test. 
+mode: test + +trainer: + callbacks: + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml new file mode 100644 index 00000000..d54b93ce --- /dev/null +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_final_lbfgs + - _self_ + +# Test the LBFGS final checkpoint on the held-out test split. The full-batch +# settings (batch_size 1e9, num_workers 0) are a TRAINING requirement for the +# convex LBFGS solve only; at test there is no optimization, so revert to a +# normal batch to avoid loading the whole test set as one tensor (OOM). +# Same model architecture as the final run (required for state_dict load). 
+mode: test + +data: + batch_size: 1024 + num_workers: 4 + train_shuffle: true + train_drop_last: true diff --git a/configs/ml/trainer/final.yaml b/configs/ml/trainer/final.yaml index 33db2b99..725803eb 100644 --- a/configs/ml/trainer/final.yaml +++ b/configs/ml/trainer/final.yaml @@ -21,3 +21,11 @@ trainer: prediction_writer: _target_: ml.callbacks.ParquetPredictionWriter output_filename: predictions.parquet + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null + write_errors: true diff --git a/ml/callbacks/__init__.py b/ml/callbacks/__init__.py index e9c20c4c..ada21e8e 100644 --- a/ml/callbacks/__init__.py +++ b/ml/callbacks/__init__.py @@ -1,4 +1,5 @@ from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter +from ml.callbacks.tiff_prediction_map_writer import TiffPredictionMapWriter -__all__ = ["ParquetPredictionWriter"] +__all__ = ["ParquetPredictionWriter", "TiffPredictionMapWriter"] diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py new file mode 100644 index 00000000..a081e2cc --- /dev/null +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -0,0 +1,324 @@ +"""Write tile predictions as WSI-aligned BigTIFF masks.""" + +from collections.abc import Mapping +from pathlib import Path +from re import sub +from tempfile import TemporaryDirectory +from typing import Any, cast + +import lightning as pl +import mlflow +import pandas as pd +import torch +from lightning.pytorch.callbacks import Callback +from mlflow.artifacts import download_artifacts + + +class TiffPredictionMapWriter(Callback): + """Collect test/predict batches and log per-slide prediction TIFFs. + + The output masks use the same coordinate space and MPP as the tiling + ``slides.parquet`` artifact. 
Class maps store the predicted class index as + ``uint8`` and use ``background_value`` for pixels without predictions. + """ + + def __init__( + self, + slides_uri: str, + artifact_path: str = "prediction_maps_tiff", + background_value: int = 255, + draw_region: str = "central_stride", + write_errors: bool = True, + max_slides: int | None = None, + slide_selection: str = "all", + ) -> None: + super().__init__() + if draw_region not in {"central_stride", "tile"}: + raise ValueError( + "draw_region must be either 'central_stride' or 'tile', " + f"got {draw_region!r}" + ) + if slide_selection not in {"all", "worst"}: + raise ValueError( + "slide_selection must be either 'all' or 'worst', " + f"got {slide_selection!r}" + ) + if max_slides is not None and max_slides <= 0: + raise ValueError(f"max_slides must be positive or None, got {max_slides}") + self.slides_uri = slides_uri + self.artifact_path = artifact_path + self.background_value = background_value + self.draw_region = draw_region + self.write_errors = write_errors + self.max_slides = max_slides + self.slide_selection = slide_selection + self._batches: list[dict[str, Any]] = [] + + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._batches.clear() + + def on_predict_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._batches.clear() + + def on_test_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: torch.Tensor | Mapping[str, Any] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if trainer.global_rank == 0 and isinstance(outputs, Mapping): + self._batches.append(_to_cpu_batch(outputs)) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: dict[str, Any] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if trainer.global_rank == 0 and outputs is not None: + 
self._batches.append(_to_cpu_batch(outputs)) + + def on_test_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._write_maps(trainer) + + def on_predict_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._write_maps(trainer) + + def _write_maps(self, trainer: pl.Trainer) -> None: + if trainer.global_rank != 0 or not self._batches: + self._batches.clear() + return + + predictions = _batches_to_dataframe(self._batches) + self._batches.clear() + if predictions.empty: + return + + slides = pd.read_parquet(_resolve_uri(self.slides_uri)) + slides_by_id = { + str(row["id"]): cast("dict[str, Any]", row) + for row in slides.to_dict(orient="records") + } + + with TemporaryDirectory(dir=Path(trainer.default_root_dir)) as output_dir: + output_path = Path(output_dir) + Path(output_path, "pred").mkdir(parents=True, exist_ok=True) + if self.write_errors: + Path(output_path, "errors").mkdir(parents=True, exist_ok=True) + + for slide_id, slide_predictions in self._select_slide_groups(predictions): + slide = slides_by_id.get(str(slide_id)) + if slide is None: + raise KeyError( + f"slide_id {slide_id!r} not found in slides artifact " + f"{self.slides_uri!r}" + ) + self._write_slide_maps(slide, slide_predictions, output_path) + + active = mlflow.active_run() + if active is not None: + mlflow.log_artifacts(output_dir, artifact_path=self.artifact_path) + + def _select_slide_groups( + self, predictions: pd.DataFrame + ) -> list[tuple[str, pd.DataFrame]]: + if self.slide_selection == "worst": + predictions = predictions.assign( + _correct=predictions["pred"] == predictions["target"] + ) + slide_ids = ( + predictions.groupby("slide_id", sort=False)["_correct"] + .mean() + .sort_values() + .index + ) + if self.max_slides is not None: + slide_ids = slide_ids[: self.max_slides] + selected = predictions[predictions["slide_id"].isin(slide_ids)] + return [ + (str(slide_id), slide_predictions.drop(columns=["_correct"])) 
+ for slide_id, slide_predictions in selected.groupby( + "slide_id", sort=False + ) + ] + + groups = [ + (str(slide_id), slide_predictions) + for slide_id, slide_predictions in predictions.groupby( + "slide_id", sort=False + ) + ] + return groups[: self.max_slides] if self.max_slides is not None else groups + + def _write_slide_maps( + self, + slide: dict[str, Any], + predictions: pd.DataFrame, + output_path: Path, + ) -> None: + filename = f"{_safe_filename(Path(str(slide['path'])).stem)}.tiff" + size = (int(slide["extent_x"]), int(slide["extent_y"])) + tile_extent = (int(slide["tile_extent_x"]), int(slide["tile_extent_y"])) + stride = (int(slide["stride_x"]), int(slide["stride_y"])) + rows = predictions.to_dict(orient="records") + + _write_vips_prediction_map( + rows=rows, + value_key="pred", + size=size, + tile_extent=tile_extent, + stride=stride, + draw_region=self.draw_region, + background_value=self.background_value, + path=Path(output_path, "pred", filename), + mpp_x=float(slide["mpp_x"]), + mpp_y=float(slide["mpp_y"]), + ) + + if not self.write_errors: + return + + _write_vips_prediction_map( + rows=rows, + value_key="_error", + size=size, + tile_extent=tile_extent, + stride=stride, + draw_region=self.draw_region, + background_value=self.background_value, + path=Path(output_path, "errors", filename), + mpp_x=float(slide["mpp_x"]), + mpp_y=float(slide["mpp_y"]), + ) + + +def _tile_box( + x: int, + y: int, + tile_extent: tuple[int, int], + stride: tuple[int, int], + draw_region: str, +) -> tuple[int, int, int, int]: + if draw_region == "tile": + left = x + top = y + width, height = tile_extent + else: + left = x + (tile_extent[0] - stride[0]) // 2 + top = y + (tile_extent[1] - stride[1]) // 2 + width, height = stride + return (left, top, left + width - 1, top + height - 1) + + +def _clip_box( + box: tuple[int, int, int, int], size: tuple[int, int] +) -> tuple[int, int, int, int] | None: + left, top, right, bottom = box + width, height = size + left = max(0, 
left) + top = max(0, top) + right = min(width - 1, right) + bottom = min(height - 1, bottom) + if right < left or bottom < top: + return None + return (left, top, right, bottom) + + +def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: + return { + key: value.detach().cpu() if isinstance(value, torch.Tensor) else value + for key, value in batch.items() + } + + +def _batches_to_dataframe(batches: list[dict[str, Any]]) -> pd.DataFrame: + rows: list[pd.DataFrame] = [] + for batch in batches: + rows.append( + pd.DataFrame( + { + "slide_id": list(batch["slide_id"]), + "x": batch["x"].numpy(), + "y": batch["y"].numpy(), + "target": batch["target"].numpy(), + "pred": batch["pred"].numpy(), + } + ) + ) + return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() + + +def _resolve_uri(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri + + +def _safe_filename(value: str) -> str: + return sub(r"[^A-Za-z0-9_.-]+", "_", value) + + +def _write_vips_prediction_map( + rows: list[dict[Any, Any]], + value_key: str, + size: tuple[int, int], + tile_extent: tuple[int, int], + stride: tuple[int, int], + draw_region: str, + background_value: int, + path: Path, + mpp_x: float, + mpp_y: float, +) -> None: + import pyvips + from rationai.masks import write_big_tiff + + path.parent.mkdir(parents=True, exist_ok=True) + width, height = size + vips_image = pyvips.Image.black(width, height).cast(pyvips.BandFormat.UCHAR) + if background_value != 0: + vips_image = vips_image + background_value + + for row in rows: + box = _clip_box( + _tile_box( + x=int(row["x"]), + y=int(row["y"]), + tile_extent=tile_extent, + stride=stride, + draw_region=draw_region, + ), + size=size, + ) + if box is None: + continue + left, top, right, bottom = box + value = ( + int(int(row["pred"]) != int(row["target"])) + if value_key == "_error" + else int(row[value_key]) + ) + drawn = vips_image.draw_rect( + value, + left, 
+ top, + right - left + 1, + bottom - top + 1, + fill=True, + ) + if drawn is not None: + vips_image = drawn + + write_big_tiff(vips_image, path=path, mpp_x=mpp_x, mpp_y=mpp_y) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 86b10eae..95bd41ce 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -1,7 +1,5 @@ from collections import defaultdict from collections.abc import Iterable -from pathlib import Path -from re import sub from typing import Any, cast import mlflow @@ -24,9 +22,6 @@ from ml.typing import Input, Outputs -MAX_TEST_PREDICTION_MAPS = 20 - - class MetaArch(LightningModule): """Top-level classification architecture: backbone + decode_head. @@ -90,7 +85,6 @@ def __init__( self._test_slide_correct: dict[str, int] = defaultdict(int) self._test_slide_total: dict[str, int] = defaultdict(int) - self._test_tile_rows: list[dict[str, Any]] = [] def setup(self, stage: str) -> None: if stage == "fit": @@ -152,7 +146,7 @@ def on_validation_epoch_end(self) -> None: self._log_per_class(self.val_per_class, "validation") self._log_confmat(self.val_confmat, "validation") - def test_step(self, batch: Input, batch_idx: int) -> None: + def test_step(self, batch: Input, batch_idx: int) -> dict[str, Any]: inputs, targets, slide_ids, xs, ys = batch outputs = self(inputs) self.test_metrics.update(outputs, targets) @@ -165,32 +159,20 @@ def test_step(self, batch: Input, batch_idx: int) -> None: for slide_id, ok in zip(slide_ids, correct, strict=True): self._test_slide_correct[slide_id] += int(ok) self._test_slide_total[slide_id] += 1 - self._test_tile_rows.extend( - { - "slide_id": slide_id, - "x": int(x), - "y": int(y), - "target": int(target), - "pred": int(pred), - } - for slide_id, x, y, target, pred in zip( - slide_ids, - xs.cpu().tolist(), - ys.cpu().tolist(), - targets.cpu().tolist(), - preds.cpu().tolist(), - strict=True, - ) - ) + return { + "slide_id": list(slide_ids), + "x": xs.cpu(), + "y": ys.cpu(), + "target": targets.cpu(), + "pred": preds.cpu(), + } def 
on_test_epoch_end(self) -> None: self._log_per_class(self.test_per_class, "test") self._log_confmat(self.test_confmat, "test") self._log_per_slide_accuracy() - self._log_prediction_maps() self._test_slide_correct.clear() self._test_slide_total.clear() - self._test_tile_rows.clear() def predict_step( self, batch: Input, batch_idx: int, dataloader_idx: int = 0 @@ -386,27 +368,6 @@ def _log_per_slide_accuracy(self) -> None: artifact_file="per_slide/test_tile_accuracy.parquet", ) - def _log_prediction_maps(self) -> None: - if not self._test_tile_rows: - return - df = pd.DataFrame(self._test_tile_rows) - df["_correct"] = df["pred"] == df["target"] - slide_order = ( - df.groupby("slide_id", sort=False)["_correct"] - .mean() - .sort_values() - .head(MAX_TEST_PREDICTION_MAPS) - .index - ) - for slide_id in slide_order: - slide_df = df[df["slide_id"] == slide_id] - fig = _prediction_map_figure(slide_df, self.class_names) - artifact_file = f"prediction_maps/{_safe_filename(slide_id)}.png" - try: - mlflow.log_figure(fig, artifact_file=artifact_file) - finally: - plt.close(fig) - def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str @@ -440,61 +401,3 @@ def _confmat_figure( fig.colorbar(im, ax=ax) fig.tight_layout() return fig - - -def _prediction_map_figure(df: pd.DataFrame, class_names: list[str]) -> Figure: - fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) - palette = plt.get_cmap("tab10", len(class_names)) - - axes[0].scatter( - df["x"], - df["y"], - c=df["pred"], - cmap=palette, - vmin=-0.5, - vmax=len(class_names) - 0.5, - marker="s", - s=4, - linewidths=0, - ) - axes[0].set_title("Prediction") - - correct = df["pred"].to_numpy() == df["target"].to_numpy() - axes[1].scatter( - df["x"], - df["y"], - c=np.where(correct, 0, 1), - cmap=plt.get_cmap("Set1", 2), - vmin=-0.5, - vmax=1.5, - marker="s", - s=4, - linewidths=0, - ) - axes[1].set_title(f"Errors ({int((~correct).sum())}/{len(df)})") - - handles = [ - plt.Line2D( - 
[0], - [0], - marker="s", - color="w", - label=cls, - markerfacecolor=palette(i), - markersize=6, - ) - for i, cls in enumerate(class_names) - ] - axes[0].legend(handles=handles, loc="upper left", bbox_to_anchor=(1.02, 1.0)) - - for ax in axes: - ax.set_aspect("equal", adjustable="box") - ax.invert_yaxis() - ax.set_xlabel("x") - ax.set_ylabel("y") - - return fig - - -def _safe_filename(value: str) -> str: - return sub(r"[^A-Za-z0-9_.-]+", "_", Path(value).stem or value) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py deleted file mode 100644 index 063952e2..00000000 --- a/scripts/submit_test_linear.py +++ /dev/null @@ -1,31 +0,0 @@ -from kube_jobs import storage, submit_job - - -fold_checkpoints = { - 0: "mlflow-artifacts:/104/26a6f9c741d543c9a09b54048be527a1/artifacts/checkpoints/epoch=3-val_loss=0.1973/checkpoint.ckpt", - 1: "mlflow-artifacts:/104/cc2be862324a446baffd9a8d90be604d/artifacts/checkpoints/epoch=1-val_loss=0.1218/checkpoint.ckpt", - 2: "mlflow-artifacts:/104/8454857b11984419bb7eae02a520ec71/artifacts/checkpoints/epoch=0-val_loss=0.2980/checkpoint.ckpt", - 3: "mlflow-artifacts:/104/bfa52277ea2744b9ab523c56a905dcda/artifacts/checkpoints/epoch=0-val_loss=1.0547/checkpoint.ckpt", - 4: "mlflow-artifacts:/104/358cd6ee286b4d67b7c12cf9bce0c3b4/artifacts/checkpoints/epoch=0-val_loss=0.1462/checkpoint.ckpt", -} - - -submit_job( - job_name="tissue-classification-test-linear", - username="vcifka", - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - *[ - "uv run python -m ml +experiment=ml/linear_classifier " - f'mode=test val_fold={fold} checkpoint=\\"{checkpoint}\\"' - for fold, checkpoint in fold_checkpoints.items() - ], - ], - storage=[storage.secure.PROJECTS], -) diff --git a/scripts/submit_train_linear_final.py b/scripts/submit_train_linear_final.py index def02bf1..b4c39404 100644 
--- a/scripts/submit_train_linear_final.py +++ b/scripts/submit_train_linear_final.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear-final", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_final", + "uv run python -m ml +experiment=...", ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_linear_final_adamw.py b/scripts/submit_train_linear_final_adamw.py deleted file mode 100644 index 7cf5ee41..00000000 --- a/scripts/submit_train_linear_final_adamw.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-linear-final-adamw", - username="vcifka", - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_final_adamw", - ], - storage=[storage.secure.PROJECTS], -) diff --git a/scripts/submit_train_linear_final_lbfgs.py b/scripts/submit_train_linear_final_lbfgs.py deleted file mode 100644 index 871f5d1a..00000000 --- a/scripts/submit_train_linear_final_lbfgs.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-linear-final-lbfgs", - username="vcifka", - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_final_lbfgs", - ], - storage=[storage.secure.PROJECTS], -) diff 
--git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 3f7ecb1a..dcfc7d93 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], ) From 4e45ce1b52577871c7c93718484c7917b12d72c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 11:04:45 +0200 Subject: [PATCH 076/107] fix: change the adamw checkpoint dir name to last --- configs/experiment/ml/linear_classifier_final_adamw.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/linear_classifier_final_adamw.yaml index 76e771d3..c798f763 100644 --- a/configs/experiment/ml/linear_classifier_final_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_final_adamw.yaml @@ -13,6 +13,11 @@ model: learning_rate: 1.0e-4 weight_decay: 1.0e-3 +trainer: + callbacks: + model_checkpoint: + save_last: true + metadata: run_name: Final Linear Classifier AdamW ${dataset.name} description: "Final AdamW linear probe over frozen Virchow2 embeddings, trained on all training folds with early stopping on train/loss_epoch." 
From 8f9ce70099348293bfeea3453dea6d895d5182e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 11:26:01 +0200 Subject: [PATCH 077/107] fix: lower the batch so the compute does not hang --- configs/experiment/ml/linear_classifier_final_lbfgs.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml index df42c9c4..1e18d47e 100644 --- a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -13,6 +13,7 @@ trainer: data: batch_size: 1000000000 + eval_batch_size: 1024 train_shuffle: false train_drop_last: false num_workers: 0 From 99c2d0d6ea679d01be12622c22a568019dc962d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 12:32:25 +0200 Subject: [PATCH 078/107] fix: put num workers to 0 --- .../ml/linear_classifier_test_adamw.yaml | 8 ++++++++ .../ml/linear_classifier_test_lbfgs.yaml | 16 ++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml index ff2431f2..b276759b 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -13,6 +13,14 @@ defaults: # WSI-aligned prediction maps as the LBFGS test. mode: test +# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into +# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing +# (no per-item IO), so workers give zero speedup; num_workers>0 forks the +# parent (pyarrow/mlflow/fsspec thread state + large array) and deadlocks +# before the first test batch. embedding_final defaults to 4; override here. 
+data: + num_workers: 0 + trainer: callbacks: tiff_prediction_maps: diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index d54b93ce..1efd0765 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -5,14 +5,18 @@ defaults: - _self_ # Test the LBFGS final checkpoint on the held-out test split. The full-batch -# settings (batch_size 1e9, num_workers 0) are a TRAINING requirement for the -# convex LBFGS solve only; at test there is no optimization, so revert to a -# normal batch to avoid loading the whole test set as one tensor (OOM). +# batch_size=1e9 is a TRAINING requirement for the convex LBFGS solve only; at +# test there is no optimization, so revert to a normal batch to avoid loading +# the whole test set as one tensor (OOM). +# +# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into +# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing +# (no per-item IO), so workers give zero speedup; num_workers>0 forks the +# parent (which holds pyarrow/mlflow/fsspec thread state + the large array), +# which deadlocks before the first test batch. # Same model architecture as the final run (required for state_dict load). 
mode: test data: batch_size: 1024 - num_workers: 4 - train_shuffle: true - train_drop_last: true + num_workers: 0 From 01486bd28c3aeb1c3720b36abd23595c471e2523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 13:56:04 +0200 Subject: [PATCH 079/107] feat: add prints --- ml/callbacks/tiff_prediction_map_writer.py | 15 ++++++++++++++- ml/data/data_module.py | 8 +++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index a081e2cc..7d07165e 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -118,17 +118,30 @@ def _write_maps(self, trainer: pl.Trainer) -> None: if self.write_errors: Path(output_path, "errors").mkdir(parents=True, exist_ok=True) - for slide_id, slide_predictions in self._select_slide_groups(predictions): + slide_groups = self._select_slide_groups(predictions) + print( + f"[TiffPredictionMapWriter] writing {len(slide_groups)} " + f"prediction map(s)" + ) + for index, (slide_id, slide_predictions) in enumerate(slide_groups, start=1): slide = slides_by_id.get(str(slide_id)) if slide is None: raise KeyError( f"slide_id {slide_id!r} not found in slides artifact " f"{self.slides_uri!r}" ) + print( + f"[TiffPredictionMapWriter] {index}/{len(slide_groups)} " + f"{Path(str(slide['path'])).name}" + ) self._write_slide_maps(slide, slide_predictions, output_path) active = mlflow.active_run() if active is not None: + print( + f"[TiffPredictionMapWriter] logging artifacts to " + f"{self.artifact_path}" + ) mlflow.log_artifacts(output_dir, artifact_path=self.artifact_path) def _select_slide_groups( diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 0f9e527d..43566b98 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -18,6 +18,7 @@ class DataModule(LightningDataModule): def __init__( self, batch_size: int, + 
eval_batch_size: int | None = None, num_workers: int = 0, train_shuffle: bool = True, train_drop_last: bool = True, @@ -25,6 +26,7 @@ def __init__( ) -> None: super().__init__() self.batch_size = batch_size + self.eval_batch_size = eval_batch_size or batch_size self.num_workers = num_workers self.train_shuffle = train_shuffle self.train_drop_last = train_drop_last @@ -64,17 +66,17 @@ def val_dataloader(self) -> Iterable[Input]: return [] return DataLoader( self.val, - batch_size=self.batch_size, + batch_size=self.eval_batch_size, num_workers=self.num_workers, persistent_workers=self.num_workers > 0, ) def test_dataloader(self) -> Iterable[Input]: return DataLoader( - self.test, batch_size=self.batch_size, num_workers=self.num_workers + self.test, batch_size=self.eval_batch_size, num_workers=self.num_workers ) def predict_dataloader(self) -> Iterable[Input]: return DataLoader( - self.predict, batch_size=self.batch_size, num_workers=self.num_workers + self.predict, batch_size=self.eval_batch_size, num_workers=self.num_workers ) From 85270fdbb984c34f2786b97228336c8a0e146850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 14:20:36 +0200 Subject: [PATCH 080/107] feat: add diagnostic prints --- ml/callbacks/tiff_prediction_map_writer.py | 7 +++++++ ml/data/datasets/embedding_tiles.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 7d07165e..bb9e006b 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -56,6 +56,7 @@ def __init__( def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._batches.clear() + print("[TiffPredictionMapWriter] test loop started", flush=True) def on_predict_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule @@ -73,6 +74,12 @@ def on_test_batch_end( ) -> None: if 
trainer.global_rank == 0 and isinstance(outputs, Mapping): self._batches.append(_to_cpu_batch(outputs)) + if batch_idx % 50 == 0: + print( + f"[TiffPredictionMapWriter] test batch {batch_idx} " + f"({len(self._batches)} buffered)", + flush=True, + ) def on_predict_batch_end( self, diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 0548dc41..0fd864c0 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -7,6 +7,7 @@ from functools import cache from pathlib import Path +from time import perf_counter import numpy as np import pandas as pd @@ -43,6 +44,15 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: + t0 = perf_counter() + + def _diag(msg: str) -> None: + print( + f"[EmbeddingTilesDataset +{perf_counter() - t0:6.1f}s] {msg}", + flush=True, + ) + + _diag("filtering metadata") meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -50,11 +60,13 @@ def __init__( include_folds, exclude_folds, ) + _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") emb_dir = self._resolve_uri(embedding_uri) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + _diag(f"embedding table loaded: {emb_table.num_rows} rows") emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): @@ -75,6 +87,7 @@ def __init__( del emb_keys, meta_table if joined_keys.num_rows == 0: raise RuntimeError("inner join with embeddings produced empty dataset") + _diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") _idx_col = joined_keys.column("_emb_idx") if isinstance(_idx_col, pa.ChunkedArray): @@ -117,6 +130,7 @@ def __init__( self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) + _diag(f"dataset 
ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") def __len__(self) -> int: return len(self.labels) From 5db671ce703053d22ec0c69484c98ffdd056a16e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 14:34:04 +0200 Subject: [PATCH 081/107] fix: use numpy buffer --- ml/callbacks/tiff_prediction_map_writer.py | 27 +++++++++++----------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index bb9e006b..203b315e 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -8,6 +8,7 @@ import lightning as pl import mlflow +import numpy as np import pandas as pd import torch from lightning.pytorch.callbacks import Callback @@ -130,7 +131,9 @@ def _write_maps(self, trainer: pl.Trainer) -> None: f"[TiffPredictionMapWriter] writing {len(slide_groups)} " f"prediction map(s)" ) - for index, (slide_id, slide_predictions) in enumerate(slide_groups, start=1): + for index, (slide_id, slide_predictions) in enumerate( + slide_groups, start=1 + ): slide = slides_by_id.get(str(slide_id)) if slide is None: raise KeyError( @@ -307,9 +310,11 @@ def _write_vips_prediction_map( path.parent.mkdir(parents=True, exist_ok=True) width, height = size - vips_image = pyvips.Image.black(width, height).cast(pyvips.BandFormat.UCHAR) - if background_value != 0: - vips_image = vips_image + background_value + + # Draw into a single numpy buffer. Chaining pyvips.draw_rect rebuilds the + # full-extent image per tile (O(n_tiles * width * height)) and hangs on WSI + # extents; numpy slice assignment is one allocation + O(total pixels). 
+ buffer = np.full((height, width), background_value, dtype=np.uint8) for row in rows: box = _clip_box( @@ -330,15 +335,9 @@ def _write_vips_prediction_map( if value_key == "_error" else int(row[value_key]) ) - drawn = vips_image.draw_rect( - value, - left, - top, - right - left + 1, - bottom - top + 1, - fill=True, - ) - if drawn is not None: - vips_image = drawn + buffer[top : bottom + 1, left : right + 1] = value + vips_image = pyvips.Image.new_from_memory( + buffer.tobytes(), width, height, 1, "uchar" + ) write_big_tiff(vips_image, path=path, mpp_x=mpp_x, mpp_y=mpp_y) From 3aea3c2d3038a43823fc59f372cc5509f3d122ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 15:40:10 +0200 Subject: [PATCH 082/107] refactor: use HeatmapAssembler --- ml/callbacks/tiff_prediction_map_writer.py | 185 ++++++++++----------- 1 file changed, 88 insertions(+), 97 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 203b315e..9f10fecf 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -34,10 +34,11 @@ def __init__( slide_selection: str = "all", ) -> None: super().__init__() - if draw_region not in {"central_stride", "tile"}: + if draw_region != "central_stride": raise ValueError( - "draw_region must be either 'central_stride' or 'tile', " - f"got {draw_region!r}" + "draw_region must be 'central_stride'; 'tile' is unsupported " + "for class maps (overlapping tiles would average categorical " + f"class indices). 
got {draw_region!r}" ) if slide_selection not in {"all", "worst"}: raise ValueError( @@ -192,71 +193,111 @@ def _write_slide_maps( output_path: Path, ) -> None: filename = f"{_safe_filename(Path(str(slide['path'])).stem)}.tiff" - size = (int(slide["extent_x"]), int(slide["extent_y"])) + extent = (int(slide["extent_x"]), int(slide["extent_y"])) tile_extent = (int(slide["tile_extent_x"]), int(slide["tile_extent_y"])) stride = (int(slide["stride_x"]), int(slide["stride_y"])) - rows = predictions.to_dict(orient="records") + mpp = (float(slide["mpp_x"]), float(slide["mpp_y"])) - _write_vips_prediction_map( - rows=rows, - value_key="pred", - size=size, + xs = predictions["x"].to_numpy(dtype=np.int64) + ys = predictions["y"].to_numpy(dtype=np.int64) + preds = predictions["pred"].to_numpy(dtype=np.int64) + + _write_assembled_map( + values=preds, + xs=xs, + ys=ys, + extent=extent, tile_extent=tile_extent, stride=stride, - draw_region=self.draw_region, background_value=self.background_value, path=Path(output_path, "pred", filename), - mpp_x=float(slide["mpp_x"]), - mpp_y=float(slide["mpp_y"]), + mpp=mpp, ) if not self.write_errors: return - _write_vips_prediction_map( - rows=rows, - value_key="_error", - size=size, + errors = ( + predictions["pred"].to_numpy() != predictions["target"].to_numpy() + ).astype(np.int64) + _write_assembled_map( + values=errors, + xs=xs, + ys=ys, + extent=extent, tile_extent=tile_extent, stride=stride, - draw_region=self.draw_region, background_value=self.background_value, path=Path(output_path, "errors", filename), - mpp_x=float(slide["mpp_x"]), - mpp_y=float(slide["mpp_y"]), + mpp=mpp, ) -def _tile_box( - x: int, - y: int, +def _write_assembled_map( + values: np.ndarray, + xs: np.ndarray, + ys: np.ndarray, + extent: tuple[int, int], tile_extent: tuple[int, int], stride: tuple[int, int], - draw_region: str, -) -> tuple[int, int, int, int]: - if draw_region == "tile": - left = x - top = y - width, height = tile_extent - else: - left = x + 
(tile_extent[0] - stride[0]) // 2 - top = y + (tile_extent[1] - stride[1]) // 2 - width, height = stride - return (left, top, left + width - 1, top + height - 1) - - -def _clip_box( - box: tuple[int, int, int, int], size: tuple[int, int] -) -> tuple[int, int, int, int] | None: - left, top, right, bottom = box - width, height = size - left = max(0, left) - top = max(0, top) - right = min(width - 1, right) - bottom = min(height - 1, bottom) - if right < left or bottom < top: - return None - return (left, top, right, bottom) + background_value: int, + path: Path, + mpp: tuple[float, float], +) -> None: + """Assemble per-tile scalar predictions into a WSI-aligned uint8 BigTIFF. + + Uses ``HeatmapAssembler`` with the tile footprint set to ``stride`` so + tiles are non-overlapping (``central_stride`` semantics): the count grid + stays <= 1, so no categorical class-index averaging occurs. The assembler + keeps a GCD-compressed grid (extent / stride), avoiding a full-extent + in-RAM buffer. Pixels never covered by a tile are written as + ``background_value``; the grid is recentered by ``(tile - stride) // 2`` + on embed to match the tile's central receptive region. 
+ """ + import pyvips + from rationai.masks import write_big_tiff + from rationai.masks.heatmap_assembler import HeatmapAssembler + + path.parent.mkdir(parents=True, exist_ok=True) + extent_x, extent_y = extent + stride_x, stride_y = stride + + assembler = HeatmapAssembler( + extent_x, + extent_y, + stride_x, + stride_y, + stride_x, + stride_y, + dtype=torch.float32, + ) + assembler.update( + torch.from_numpy(values.astype(np.float32)), + torch.from_numpy(xs), + torch.from_numpy(ys), + ) + + grid = assembler.compute().round().to(torch.uint8).numpy() + grid[assembler._count.numpy() == 0] = background_value + grid = np.ascontiguousarray(grid) + + mask = pyvips.Image.new_from_array(grid).cast(pyvips.BandFormat.UCHAR) + mask = mask.resize( + assembler.common_divisor_x, + vscale=assembler.common_divisor_y, + kernel=pyvips.enums.Kernel.NEAREST, + ) + offset_x = (tile_extent[0] - stride_x) // 2 + offset_y = (tile_extent[1] - stride_y) // 2 + mask = mask.embed( + offset_x, + offset_y, + extent_x, + extent_y, + extend=pyvips.enums.Extend.BACKGROUND, + background=[background_value], + ) + write_big_tiff(mask, path, mpp[0], mpp[1]) def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: @@ -291,53 +332,3 @@ def _resolve_uri(uri: str) -> str: def _safe_filename(value: str) -> str: return sub(r"[^A-Za-z0-9_.-]+", "_", value) - - -def _write_vips_prediction_map( - rows: list[dict[Any, Any]], - value_key: str, - size: tuple[int, int], - tile_extent: tuple[int, int], - stride: tuple[int, int], - draw_region: str, - background_value: int, - path: Path, - mpp_x: float, - mpp_y: float, -) -> None: - import pyvips - from rationai.masks import write_big_tiff - - path.parent.mkdir(parents=True, exist_ok=True) - width, height = size - - # Draw into a single numpy buffer. Chaining pyvips.draw_rect rebuilds the - # full-extent image per tile (O(n_tiles * width * height)) and hangs on WSI - # extents; numpy slice assignment is one allocation + O(total pixels). 
- buffer = np.full((height, width), background_value, dtype=np.uint8) - - for row in rows: - box = _clip_box( - _tile_box( - x=int(row["x"]), - y=int(row["y"]), - tile_extent=tile_extent, - stride=stride, - draw_region=draw_region, - ), - size=size, - ) - if box is None: - continue - left, top, right, bottom = box - value = ( - int(int(row["pred"]) != int(row["target"])) - if value_key == "_error" - else int(row[value_key]) - ) - buffer[top : bottom + 1, left : right + 1] = value - - vips_image = pyvips.Image.new_from_memory( - buffer.tobytes(), width, height, 1, "uchar" - ) - write_big_tiff(vips_image, path=path, mpp_x=mpp_x, mpp_y=mpp_y) From 756642a27ac80d8706141933b9e3805b9103c8d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 17:18:55 +0200 Subject: [PATCH 083/107] chore: clean config structure --- .../ml/linear_classifier_final_adamw.yaml | 2 +- .../ml/linear_classifier_final_lbfgs.yaml | 2 +- .../ml/linear_classifier_final.yaml | 0 ml/callbacks/tiff_prediction_map_writer.py | 223 +++++++++++++----- ml/meta_arch.py | 1 + ..._linear_final.py => submit_test_linear.py} | 6 +- 6 files changed, 170 insertions(+), 64 deletions(-) rename configs/{experiment => }/ml/linear_classifier_final.yaml (100%) rename scripts/{submit_test_linear_final.py => submit_test_linear.py} (57%) diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/linear_classifier_final_adamw.yaml index c798f763..2d0feba8 100644 --- a/configs/experiment/ml/linear_classifier_final_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_final_adamw.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_final + - /ml/linear_classifier_final - override /ml/trainer: default - _self_ diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml index 1e18d47e..f83f3626 100644 --- 
a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /experiment/ml/linear_classifier_final + - /ml/linear_classifier_final - _self_ # LBFGS final: exact solve of the convex objective on the full training batch. diff --git a/configs/experiment/ml/linear_classifier_final.yaml b/configs/ml/linear_classifier_final.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_final.yaml rename to configs/ml/linear_classifier_final.yaml diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 9f10fecf..d6254088 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -200,10 +200,10 @@ def _write_slide_maps( xs = predictions["x"].to_numpy(dtype=np.int64) ys = predictions["y"].to_numpy(dtype=np.int64) - preds = predictions["pred"].to_numpy(dtype=np.int64) + probs = np.stack(predictions["probs"].to_numpy()).astype(np.float32) - _write_assembled_map( - values=preds, + _write_class_map( + probs=probs, xs=xs, ys=ys, extent=extent, @@ -219,9 +219,9 @@ def _write_slide_maps( errors = ( predictions["pred"].to_numpy() != predictions["target"].to_numpy() - ).astype(np.int64) - _write_assembled_map( - values=errors, + ).astype(np.float32) + _write_error_map( + errors=errors, xs=xs, ys=ys, extent=extent, @@ -233,8 +233,133 @@ def _write_slide_maps( ) -def _write_assembled_map( - values: np.ndarray, +class _ClassVoteAssembler: + """Confidence-weighted per-class accumulator over the full tile footprint. + + Mirrors ``HeatmapAssembler``'s GCD-compressed grid + ``(x, y)`` -> ROI + mapping, but accumulates the per-class softmax of every tile that covers a + cell instead of summing a categorical index. Overlapping strided tiles are + fused by ``argmax`` over the summed probabilities, so heavy stride overlap + (e.g. 
tile 224 / stride 112) is used instead of discarded. ``count == 0`` + marks never-covered cells as background. + """ + + def __init__( + self, + extent_x: int, + extent_y: int, + tile_x: int, + tile_y: int, + stride_x: int, + stride_y: int, + n_classes: int, + ) -> None: + from math import gcd + + self.cdx = gcd(stride_x, tile_x) + self.cdy = gcd(stride_y, tile_y) + self._sx = stride_x // self.cdx + self._sy = stride_y // self.cdy + self._tx = tile_x // self.cdx + self._ty = tile_y // self.cdy + self._gx = extent_x // self.cdx + self._gy = extent_y // self.cdy + self._acc = torch.zeros( + n_classes, self._gy, self._gx, dtype=torch.float32 + ) + self._count = torch.zeros(self._gy, self._gx, dtype=torch.int32) + + def update( + self, probs: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor + ) -> None: + for prob, x, y in zip(probs, xs, ys, strict=False): + cx = int(x.item()) // self.cdx + cy = int(y.item()) // self.cdy + x0, y0 = cx * self._sx, cy * self._sy + x1 = min(x0 + self._tx, self._gx) + y1 = min(y0 + self._ty, self._gy) + if x1 <= x0 or y1 <= y0: + continue + self._acc[:, y0:y1, x0:x1] += prob[:, None, None] + self._count[y0:y1, x0:x1] += 1 + + def labels(self, background_value: int) -> np.ndarray: + labels = self._acc.argmax(0).to(torch.uint8) + labels[self._count == 0] = background_value + return np.ascontiguousarray(labels.numpy()) + + +def _emit_mask( + grid: np.ndarray, + cdx: int, + cdy: int, + extent: tuple[int, int], + background_value: int, + path: Path, + mpp: tuple[float, float], +) -> None: + """Upscale the GCD-compressed grid back to WSI extent and write a BigTIFF. + + NEAREST resize by the GCD factor (tile footprint placed at its true + position, no recenter), background-padded to the full extent. 
+ """ + import pyvips + from rationai.masks import write_big_tiff + + path.parent.mkdir(parents=True, exist_ok=True) + mask = pyvips.Image.new_from_array(grid).cast(pyvips.BandFormat.UCHAR) + mask = mask.resize( + cdx, vscale=cdy, kernel=pyvips.enums.Kernel.NEAREST + ) + mask = mask.embed( + 0, + 0, + extent[0], + extent[1], + extend=pyvips.enums.Extend.BACKGROUND, + background=[background_value], + ) + write_big_tiff(mask, path, mpp[0], mpp[1]) + + +def _write_class_map( + probs: np.ndarray, + xs: np.ndarray, + ys: np.ndarray, + extent: tuple[int, int], + tile_extent: tuple[int, int], + stride: tuple[int, int], + background_value: int, + path: Path, + mpp: tuple[float, float], +) -> None: + assembler = _ClassVoteAssembler( + extent[0], + extent[1], + tile_extent[0], + tile_extent[1], + stride[0], + stride[1], + n_classes=probs.shape[1], + ) + assembler.update( + torch.from_numpy(probs), + torch.from_numpy(xs), + torch.from_numpy(ys), + ) + _emit_mask( + assembler.labels(background_value), + assembler.cdx, + assembler.cdy, + extent, + background_value, + path, + mpp, + ) + + +def _write_error_map( + errors: np.ndarray, xs: np.ndarray, ys: np.ndarray, extent: tuple[int, int], @@ -244,60 +369,38 @@ def _write_assembled_map( path: Path, mpp: tuple[float, float], ) -> None: - """Assemble per-tile scalar predictions into a WSI-aligned uint8 BigTIFF. - - Uses ``HeatmapAssembler`` with the tile footprint set to ``stride`` so - tiles are non-overlapping (``central_stride`` semantics): the count grid - stays <= 1, so no categorical class-index averaging occurs. The assembler - keeps a GCD-compressed grid (extent / stride), avoiding a full-extent - in-RAM buffer. Pixels never covered by a tile are written as - ``background_value``; the grid is recentered by ``(tile - stride) // 2`` - on embed to match the tile's central receptive region. + """Per-tile error (pred != target) averaged over the full tile footprint. 
+ + Averaging a 0/1 error fraction across overlapping tiles is meaningful + (fraction of covering tiles that were wrong); thresholded back to {0, 1}. """ - import pyvips - from rationai.masks import write_big_tiff from rationai.masks.heatmap_assembler import HeatmapAssembler - path.parent.mkdir(parents=True, exist_ok=True) - extent_x, extent_y = extent - stride_x, stride_y = stride - assembler = HeatmapAssembler( - extent_x, - extent_y, - stride_x, - stride_y, - stride_x, - stride_y, + extent[0], + extent[1], + tile_extent[0], + tile_extent[1], + stride[0], + stride[1], dtype=torch.float32, ) assembler.update( - torch.from_numpy(values.astype(np.float32)), + torch.from_numpy(errors), torch.from_numpy(xs), torch.from_numpy(ys), ) - - grid = assembler.compute().round().to(torch.uint8).numpy() + grid = (assembler.compute() >= 0.5).to(torch.uint8).numpy() grid[assembler._count.numpy() == 0] = background_value - grid = np.ascontiguousarray(grid) - - mask = pyvips.Image.new_from_array(grid).cast(pyvips.BandFormat.UCHAR) - mask = mask.resize( + _emit_mask( + np.ascontiguousarray(grid), assembler.common_divisor_x, - vscale=assembler.common_divisor_y, - kernel=pyvips.enums.Kernel.NEAREST, + assembler.common_divisor_y, + extent, + background_value, + path, + mpp, ) - offset_x = (tile_extent[0] - stride_x) // 2 - offset_y = (tile_extent[1] - stride_y) // 2 - mask = mask.embed( - offset_x, - offset_y, - extent_x, - extent_y, - extend=pyvips.enums.Extend.BACKGROUND, - background=[background_value], - ) - write_big_tiff(mask, path, mpp[0], mpp[1]) def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: @@ -310,17 +413,19 @@ def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: def _batches_to_dataframe(batches: list[dict[str, Any]]) -> pd.DataFrame: rows: list[pd.DataFrame] = [] for batch in batches: - rows.append( - pd.DataFrame( - { - "slide_id": list(batch["slide_id"]), - "x": batch["x"].numpy(), - "y": batch["y"].numpy(), - "target": batch["target"].numpy(), 
- "pred": batch["pred"].numpy(), - } - ) + frame = pd.DataFrame( + { + "slide_id": list(batch["slide_id"]), + "x": batch["x"].numpy(), + "y": batch["y"].numpy(), + "target": batch["target"].numpy(), + "pred": batch["pred"].numpy(), + } ) + # per-row softmax vector kept as an object column so it survives the + # per-slide groupby; stacked back to (n, C) in the assembler. + frame["probs"] = list(batch["probs"].numpy()) + rows.append(frame) return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 95bd41ce..a20cf4ce 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -165,6 +165,7 @@ def test_step(self, batch: Input, batch_idx: int) -> dict[str, Any]: "y": ys.cpu(), "target": targets.cpu(), "pred": preds.cpu(), + "probs": outputs.softmax(dim=1).cpu(), } def on_test_epoch_end(self) -> None: diff --git a/scripts/submit_test_linear_final.py b/scripts/submit_test_linear.py similarity index 57% rename from scripts/submit_test_linear_final.py rename to scripts/submit_test_linear.py index c9ad162b..c7860485 100644 --- a/scripts/submit_test_linear_final.py +++ b/scripts/submit_test_linear.py @@ -1,11 +1,11 @@ from kube_jobs import storage, submit_job -checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" +checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" submit_job( - job_name="tissue-classification-test-linear-final", + job_name="tissue-classification-test-linear-final-...", username=..., cpu=8, memory="64Gi", @@ -15,7 +15,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - f'uv run python -m ml +experiment=... mode=test checkpoint=\\"{checkpoint}\\"', + f'uv run python -m ml +experiment=... 
checkpoint=\\"{checkpoint}\\"', ], storage=[storage.secure.PROJECTS], ) From 2771e78ff520184eeda417e5d503a340a05796e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 17:27:18 +0200 Subject: [PATCH 084/107] fix: prediction maps class indices --- ml/callbacks/tiff_prediction_map_writer.py | 49 +++++++++++++++------- scripts/submit_test_linear.py | 10 ++--- 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index d6254088..4e58b700 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -209,7 +209,6 @@ def _write_slide_maps( extent=extent, tile_extent=tile_extent, stride=stride, - background_value=self.background_value, path=Path(output_path, "pred", filename), mpp=mpp, ) @@ -227,7 +226,6 @@ def _write_slide_maps( extent=extent, tile_extent=tile_extent, stride=stride, - background_value=self.background_value, path=Path(output_path, "errors", filename), mpp=mpp, ) @@ -264,6 +262,7 @@ def __init__( self._ty = tile_y // self.cdy self._gx = extent_x // self.cdx self._gy = extent_y // self.cdy + self._n_classes = n_classes self._acc = torch.zeros( n_classes, self._gy, self._gx, dtype=torch.float32 ) @@ -283,10 +282,17 @@ def update( self._acc[:, y0:y1, x0:x1] += prob[:, None, None] self._count[y0:y1, x0:x1] += 1 - def labels(self, background_value: int) -> np.ndarray: - labels = self._acc.argmax(0).to(torch.uint8) - labels[self._count == 0] = background_value - return np.ascontiguousarray(labels.numpy()) + def labels(self) -> np.ndarray: + """Encode like ``remap_annotation_masks``: class ``i`` (0-based) -> + ``round(255 * (i + 1) / n_classes)``, never-covered pixels -> ``0``. + + The reporting tool expects GT and prediction masks in the same + evenly-spread value space, so this must mirror that LUT exactly. 
+ """ + idx = self._acc.argmax(0).numpy() + out = _spread_lut(self._n_classes)[idx].astype(np.uint8) + out[self._count.numpy() == 0] = 0 + return np.ascontiguousarray(out) def _emit_mask( @@ -329,7 +335,6 @@ def _write_class_map( extent: tuple[int, int], tile_extent: tuple[int, int], stride: tuple[int, int], - background_value: int, path: Path, mpp: tuple[float, float], ) -> None: @@ -348,11 +353,11 @@ def _write_class_map( torch.from_numpy(ys), ) _emit_mask( - assembler.labels(background_value), + assembler.labels(), assembler.cdx, assembler.cdy, extent, - background_value, + 0, path, mpp, ) @@ -365,14 +370,15 @@ def _write_error_map( extent: tuple[int, int], tile_extent: tuple[int, int], stride: tuple[int, int], - background_value: int, path: Path, mpp: tuple[float, float], ) -> None: """Per-tile error (pred != target) averaged over the full tile footprint. Averaging a 0/1 error fraction across overlapping tiles is meaningful - (fraction of covering tiles that were wrong); thresholded back to {0, 1}. + (fraction of covering tiles that were wrong); thresholded back to a 2-class + map encoded in the same spread space as the GT (correct -> low value, + wrong -> 255), background -> 0. 
""" from rationai.masks.heatmap_assembler import HeatmapAssembler @@ -390,14 +396,15 @@ def _write_error_map( torch.from_numpy(xs), torch.from_numpy(ys), ) - grid = (assembler.compute() >= 0.5).to(torch.uint8).numpy() - grid[assembler._count.numpy() == 0] = background_value + wrong = (assembler.compute() >= 0.5).numpy().astype(np.intp) + grid = _spread_lut(2)[wrong] # correct -> 128, wrong -> 255 + grid[assembler._count.numpy() == 0] = 0 _emit_mask( np.ascontiguousarray(grid), assembler.common_divisor_x, assembler.common_divisor_y, extent, - background_value, + 0, path, mpp, ) @@ -437,3 +444,17 @@ def _resolve_uri(uri: str) -> str: def _safe_filename(value: str) -> str: return sub(r"[^A-Za-z0-9_.-]+", "_", value) + + +def _spread_lut(n_classes: int) -> np.ndarray: + """0-based class index -> evenly-spread uint8 value. + + Mirrors ``preprocessing/remap_annotation_masks.py`` exactly so prediction + masks share the GT value space the reporting tool expects: + class ``i`` -> ``round(255 * (i + 1) / n_classes)``. Index 0 of the + returned array is unused by callers (background is set separately to 0). 
+ """ + return np.array( + [round(255 * (i + 1) / n_classes) for i in range(n_classes)], + dtype=np.uint8, + ) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index c7860485..4af3904c 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -1,21 +1,21 @@ from kube_jobs import storage, submit_job -checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" +checkpoint = "mlflow-artifacts:/104/a23e478b00b04da79cfbf4d91cada8cd/artifacts/checkpoints/last/checkpoint.ckpt" submit_job( - job_name="tissue-classification-test-linear-final-...", - username=..., + job_name="tissue-classification-test-linear-final-adamw", + username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - f'uv run python -m ml +experiment=... 
checkpoint=\\"{checkpoint}\\"', + f'PYTHONUNBUFFERED=1 uv run python -m ml +experiment=ml/linear_classifier_test_adamw mode=test checkpoint=\\"{checkpoint}\\"', ], storage=[storage.secure.PROJECTS], ) From 918b691168b07c327b525e0bba943a3986e03931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 16 May 2026 17:30:30 +0200 Subject: [PATCH 085/107] fix: format and mypy --- ml/callbacks/tiff_prediction_map_writer.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 4e58b700..6e302a61 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -200,7 +200,9 @@ def _write_slide_maps( xs = predictions["x"].to_numpy(dtype=np.int64) ys = predictions["y"].to_numpy(dtype=np.int64) - probs = np.stack(predictions["probs"].to_numpy()).astype(np.float32) + probs = np.stack( + [np.asarray(prob, dtype=np.float32) for prob in predictions["probs"]] + ) _write_class_map( probs=probs, @@ -263,14 +265,10 @@ def __init__( self._gx = extent_x // self.cdx self._gy = extent_y // self.cdy self._n_classes = n_classes - self._acc = torch.zeros( - n_classes, self._gy, self._gx, dtype=torch.float32 - ) + self._acc = torch.zeros(n_classes, self._gy, self._gx, dtype=torch.float32) self._count = torch.zeros(self._gy, self._gx, dtype=torch.int32) - def update( - self, probs: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor - ) -> None: + def update(self, probs: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor) -> None: for prob, x, y in zip(probs, xs, ys, strict=False): cx = int(x.item()) // self.cdx cy = int(y.item()) // self.cdy @@ -283,8 +281,10 @@ def update( self._count[y0:y1, x0:x1] += 1 def labels(self) -> np.ndarray: - """Encode like ``remap_annotation_masks``: class ``i`` (0-based) -> - ``round(255 * (i + 1) / n_classes)``, never-covered pixels -> ``0``. 
+ """Encode prediction labels like ``remap_annotation_masks``. + + Class ``i`` (0-based) maps to + ``round(255 * (i + 1) / n_classes)``. Never-covered pixels map to ``0``. The reporting tool expects GT and prediction masks in the same evenly-spread value space, so this must mirror that LUT exactly. @@ -314,9 +314,7 @@ def _emit_mask( path.parent.mkdir(parents=True, exist_ok=True) mask = pyvips.Image.new_from_array(grid).cast(pyvips.BandFormat.UCHAR) - mask = mask.resize( - cdx, vscale=cdy, kernel=pyvips.enums.Kernel.NEAREST - ) + mask = mask.resize(cdx, vscale=cdy, kernel=pyvips.enums.Kernel.NEAREST) mask = mask.embed( 0, 0, From 4032df39eb8cc9fed2956b15a1967939f8a9967e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 17 May 2026 10:32:45 +0200 Subject: [PATCH 086/107] feat: add posibility to predict the whole slide with tissue area --- configs/data/dataset.yaml | 2 +- ...inear_classifier_predict_tissue_tiles.yaml | 43 ++++++ .../ml/linear_classifier_test_adamw.yaml | 11 -- .../ml/linear_classifier_test_lbfgs.yaml | 1 + ...mbeddings_virchow2_tissue_tiles_05mpp.yaml | 17 +++ configs/preprocessing/embeddings.yaml | 10 ++ docs/stratified_group_kfold.md | 141 ------------------ ml/data/datasets/__init__.py | 7 +- ml/data/datasets/embedding_tiles.py | 118 +++++++++++++++ scripts/submit_predict_tissue_tiles.py | 29 ++++ scripts/submit_test_linear.py | 10 +- 11 files changed, 229 insertions(+), 160 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml create mode 100644 configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml delete mode 100644 docs/stratified_group_kfold.md create mode 100644 scripts/submit_predict_tissue_tiles.py diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 09f8f4a1..f0613710 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: 
"split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5" + stratified_kfold_run_id: "850c81506684450b9af92296acfd045a" stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml new file mode 100644 index 00000000..7d0d6623 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -0,0 +1,43 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier_final + - _self_ + +# Predict over every test-split tile that intersects the tissue mask. This run +# is unlabeled: it writes predictions/maps for review and does not compute test +# metrics. +mode: predict + +tissue_embedding_run_id: ??? +tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} +tissue_stats_artifact_path: tissue_stats +tissue_column: tile_tissue_coverage +tissue_min: 0.0 + +test_embedding_uri: runs:/${tissue_embedding_run_id}/test/tiles +test_metadata_uri: runs:/${tissue_stats_run_id}/${tissue_stats_artifact_path}/test_tiles.parquet + +data: + num_workers: 0 + predict: + _target_: ml.data.datasets.UnlabeledEmbeddingTilesDataset + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} + tissue_column: ${tissue_column} + tissue_min: ${tissue_min} + +trainer: + callbacks: + tiff_prediction_maps: + write_errors: false + +metadata: + run_name: Predict Linear Classifier tissue tiles ${dataset.name} + description: "Predict final linear classifier over all held-out test tiles intersecting tissue masks for external doctor review." 
+ hyperparams: + tissue_embedding_run_id: ${tissue_embedding_run_id} + tissue_stats_run_id: ${tissue_stats_run_id} + tissue_stats_artifact_path: ${tissue_stats_artifact_path} + tissue_column: ${tissue_column} + tissue_min: ${tissue_min} diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml index b276759b..c3400d79 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -8,9 +8,6 @@ defaults: # architecture as the final run (required for state_dict load); optimizer # fields are inert at test. Checkpoint is passed on the CLI. # -# The AdamW final inherits trainer/default (early stopping for fit), which has -# no TIFF map writer. Add it here so the AdamW test produces the same -# WSI-aligned prediction maps as the LBFGS test. mode: test # num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into @@ -20,11 +17,3 @@ mode: test # before the first test batch. embedding_final defaults to 4; override here. data: num_workers: 0 - -trainer: - callbacks: - tiff_prediction_maps: - _target_: ml.callbacks.TiffPredictionMapWriter - slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet - artifact_path: prediction_maps_tiff - draw_region: central_stride diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index 1efd0765..d455774d 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -2,6 +2,7 @@ defaults: - /experiment/ml/linear_classifier_final_lbfgs + - override /ml/trainer: default - _self_ # Test the LBFGS final checkpoint on the held-out test split. 
The full-batch diff --git a/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml new file mode 100644 index 00000000..4ea45ff7 --- /dev/null +++ b/configs/experiment/preprocessing/embeddings_virchow2_tissue_tiles_05mpp.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +defaults: + - /experiment/preprocessing/embeddings_virchow2_05mpp + - _self_ + +# Embeddings for every tile in the train/test tiling split that intersects the +# tissue mask. This is the tile universe used for doctor-review prediction maps. +splits: + - test +tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} +tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet" +tile_filter_column: tile_tissue_coverage + +metadata: + run_name: "Embeddings: ${model} tissue tiles" + description: "Tile embeddings using ${model} over held-out test split tiles with tile_tissue_coverage > 0." diff --git a/configs/preprocessing/embeddings.yaml b/configs/preprocessing/embeddings.yaml index db9ceaa4..91400ac6 100644 --- a/configs/preprocessing/embeddings.yaml +++ b/configs/preprocessing/embeddings.yaml @@ -5,6 +5,12 @@ output_dir: ${project_path}/embeddings concurrency: 512 block_size: 2048 rows_per_file: 5000 +splits: + - train + - test +tile_source_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} +tile_source_artifact_template: "filter_tiles/{split}_tiles.parquet" +tile_filter_column: null metadata: run_name: "Embeddings: ${model}" @@ -13,3 +19,7 @@ metadata: model: ${model} concurrency: ${concurrency} block_size: ${block_size} + splits: ${splits} + tile_source_run_id: ${tile_source_run_id} + tile_source_artifact_template: ${tile_source_artifact_template} + tile_filter_column: ${tile_filter_column} diff --git a/docs/stratified_group_kfold.md b/docs/stratified_group_kfold.md deleted file mode 100644 index c5e2ffa1..00000000 --- a/docs/stratified_group_kfold.md +++ /dev/null @@ -1,141 +0,0 
@@ -# Stratified Group K-Fold Split - -## Motivation - -The original tile-level `StratifiedKFold` split balanced tissue labels well, but it allowed tiles from the same slide to appear in both training and validation partitions. Because tiles from a single whole-slide image are not independent, this can leak slide-specific visual patterns into validation and make validation performance overly optimistic. - -To reduce this leakage risk, the splitter now supports `kfold_strategy: stratified_group`, implemented in `split/kfold_split.py`. This mode uses `StratifiedGroupKFold`, stratifying by tissue label while treating `slide_id` as the grouping variable. As a result, all tiles from the same slide are assigned to exactly one validation fold. - -The original tile-level `StratifiedKFold` strategy is still available as `kfold_strategy: stratified`. It can be run when slide-level separation is not required, for example for debugging, comparison against older experiments, or workflows where tile-level stratification is intentionally preferred. - -## Stratification Target and Rare-Class Protocol - -Labels are derived per tile from the `roi_coverage_*` columns: - -- `label` is the tissue class with the highest ROI coverage. -- `background` is assigned when a tile has zero ROI coverage. -- `tissue_prop` is the sum of all `roi_coverage_*` values for the tile. - -For grouped splitting, the important constraint is no longer only the number of tiles per class. `StratifiedGroupKFold` also needs each stratification class to be represented across enough distinct groups. In this project, a group is a slide, so each retained class must appear in at least `n_folds` distinct `slide_id` values. - -The rare-class protocol for `stratified_group` is therefore slide-based: - -- The splitter counts the number of distinct slides containing each label. -- Any label present in fewer than `n_folds` slides is considered rare for grouped splitting. 
-- All tiles with rare labels are dropped before fold assignment. -- A warning lists each dropped label and the number of slides in which it appears. -- If the rare-class filtering would drop every tile, the script raises a `ValueError`. - -This differs from the older `stratified` strategy. The tile-level strategy collapses rare tile-count classes into `background` only for stratification. The grouped strategy does not collapse rare classes into `background`, because background can be sparse or filtered upstream and because collapsing would not reliably solve the slide-level group constraint. - -## Fold Assignment - -`StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=...)` is fitted with: - -- `y`: the derived tile label. -- `groups`: the tile `slide_id`. - -Each tile is assigned to its validation fold, `fold in [0, n_folds)`. For any fold `k`, the validation set is `fold == k` and the training set is the complement, `fold != k`. - -The group constraint means that each slide appears in only one validation fold. Consequently, no tile from a validation slide appears in the corresponding training split. - -## Output - -The script writes one parquet artifact, `kfold_tiles.parquet`, under the configured `mlflow_artifact_path`. - -The output keeps the filtered input tile dataset and adds fold metadata: - -| Column | Type | Source | -| --- | --- | --- | -| `tissue_prop` | float | Sum of `roi_coverage_*` columns. | -| `fold` | int8 | Validation fold index in `[0, n_folds)`. | - -For `stratified_group`, rare labels may be removed before writing the parquet. When this happens, the logged metric `dropped_rare_class_tiles` records how many tiles were excluded. - -Note: labels are derived inside the splitter for stratification and statistics. The current implementation does not add a new `label` column to the output parquet unless such a column is already present in the input dataset. 
- -## Logged Statistics - -Per-fold metrics are emitted to MLflow: - -- `fold__train_tiles`: number of tiles outside validation fold `k`. -- `fold__val_tiles`: number of validation tiles in fold `k`. -- `fold__val_tile_pct`: fraction of retained tiles assigned to validation fold `k`. -- `fold__val_slides`: number of distinct validation slides in fold `k`. -- `fold__val_tissue_prop_mean`: mean tissue coverage in validation fold `k`. -- `fold__val_tissue_prop_std`: tissue coverage standard deviation in validation fold `k`. -- `fold_size_cv`: coefficient of variation of validation fold sizes. -- `dropped_rare_class_tiles`: number of dropped rare-class tiles, logged only when rare-class filtering removes tiles. - -The script also logs a label-distribution table as an MLflow artifact: - -- `fold_statistics/label_distribution.json`: fold by original derived label counts. - -Unlike the tile-level `stratified` strategy, the `stratified_group` strategy does not log `fold_statistics/stratification_label_distribution.json`, because it uses the original derived labels directly and does not create a separate collapsed stratification-label array. - -## Split Statistics - -Detailed JSON representations of the metrics are available within the respective MLflow run artifacts. 
- -### Global Metrics - -- Total retained tiles: 1,102,086 -- Original tiles before rare-class filtering: 1,102,086 -- Dropped rare-class tiles: 0 -- n_folds: 5 -- Random state: 42 -- K-fold strategy: `stratified_group` -- Rare labels dropped before splitting: none -- Fold size CV: 0.0517 - -### Per-Fold Metrics - -| Fold | Train tiles | Val tiles | Val % | Val slides | tissue_prop mean +- std | -| --- | --- | --- | --- | --- | --- | -| 0 | 859,277 | 242,809 | 22.03% | 26 | 0.9077 +- 0.2667 | -| 1 | 890,838 | 211,248 | 19.17% | 26 | 0.8809 +- 0.2981 | -| 2 | 884,420 | 217,666 | 19.75% | 27 | 0.8979 +- 0.2797 | -| 3 | 886,197 | 215,889 | 19.59% | 27 | 0.8705 +- 0.3084 | -| 4 | 887,612 | 214,474 | 19.46% | 31 | 0.8812 +- 0.2978 | - -### Original Label Distribution per Fold - -For `stratified_group`, this table reflects the labels that were retained after rare-class filtering. - -| Label | Fold 0 | Fold 1 | Fold 2 | Fold 3 | Fold 4 | -| --- | --- | --- | --- | --- | --- | -| background | 4.4941% | 6.0493% | 5.2140% | 6.5742% | 6.1728% | -| Blood | 0.4156% | 0.5027% | 0.5035% | 0.4530% | 0.9059% | -| Connective-Tissue | 3.1514% | 3.4386% | 2.9573% | 3.3703% | 3.8252% | -| Epithelium | 1.2413% | 1.4268% | 1.3948% | 1.3887% | 1.4118% | -| Fat | 10.8517% | 10.8948% | 13.8033% | 12.2832% | 10.8428% | -| Muscle | 14.7194% | 2.9955% | 3.2605% | 2.4929% | 3.3776% | -| Nerve | 1.6153% | 1.8968% | 1.8556% | 1.7856% | 1.8725% | -| Other | 63.5112% | 72.7955% | 71.0111% | 71.6521% | 71.5914% | - -### Slide Distribution per Fold - -Use this table to document how slides are distributed across validation folds. This is the key leakage-control diagnostic for the grouped split. 
- -| Fold | Val slides | Val slide % | Val tiles | Val tile % | -| --- | --- | --- | --- | --- | -| 0 | 26 | 18.98% | 242,809 | 22.03% | -| 1 | 26 | 18.98% | 211,248 | 19.17% | -| 2 | 27 | 19.71% | 217,666 | 19.75% | -| 3 | 27 | 19.71% | 215,889 | 19.59% | -| 4 | 31 | 22.63% | 214,474 | 19.46% | - -### Optional: Label Counts per Fold - -Use this table if you want to report absolute counts in addition to percentages. - -| Label | Fold 0 | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Total | -| --- | --- | --- | --- | --- | --- | --- | -| background | 10,912 | 12,779 | 11,349 | 14,193 | 13,239 | 62,472 | -| Blood | 1,009 | 1,062 | 1,096 | 978 | 1,943 | 6,088 | -| Connective-Tissue | 7,652 | 7,264 | 6,437 | 7,276 | 8,204 | 36,833 | -| Epithelium | 3,014 | 3,014 | 3,036 | 2,998 | 3,028 | 15,090 | -| Fat | 26,349 | 23,015 | 30,045 | 26,518 | 23,255 | 129,182 | -| Muscle | 35,740 | 6,328 | 7,097 | 5,382 | 7,244 | 61,791 | -| Nerve | 3,922 | 4,007 | 4,039 | 3,855 | 4,016 | 19,839 | -| Other | 154,211 | 153,779 | 154,567 | 154,689 | 153,545 | 770,791 | diff --git a/ml/data/datasets/__init__.py b/ml/data/datasets/__init__.py index cd2f91a9..f3a5b475 100644 --- a/ml/data/datasets/__init__.py +++ b/ml/data/datasets/__init__.py @@ -1,4 +1,7 @@ -from ml.data.datasets.embedding_tiles import EmbeddingTilesDataset +from ml.data.datasets.embedding_tiles import ( + EmbeddingTilesDataset, + UnlabeledEmbeddingTilesDataset, +) -__all__ = ["EmbeddingTilesDataset"] +__all__ = ["EmbeddingTilesDataset", "UnlabeledEmbeddingTilesDataset"] diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 0fd864c0..6a8a2a0f 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -213,3 +213,121 @@ def _resolve_uri_cached(uri: str) -> str: if uri.startswith(("mlflow-artifacts:/", "runs:/")): return download_artifacts(artifact_uri=uri) return uri + + +class UnlabeledEmbeddingTilesDataset(Dataset[Sample]): + """Tile-embedding dataset for 
prediction over tiles without class labels.""" + + def __init__( + self, + embedding_uri: str | Path, + metadata_uri: str | Path, + tissue_column: str = "tile_tissue_coverage", + tissue_min: float = 0.0, + label_value: int = -1, + ) -> None: + t0 = perf_counter() + + def _diag(msg: str) -> None: + print( + f"[UnlabeledEmbeddingTilesDataset +{perf_counter() - t0:6.1f}s] {msg}", + flush=True, + ) + + _diag("filtering metadata") + meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min) + _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") + + emb_dir = EmbeddingTilesDataset._resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + _diag(f"embedding table loaded: {emb_table.num_rows} rows") + + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table + + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + _diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") + + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() + + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // len(first_chunk) + sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate( + [[0], 
np.cumsum([len(c) for c in emb_col.chunks])] + ) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & (sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col + + self.embeddings = embeddings + self.labels = np.full(len(joined_keys), label_value, dtype=np.int64) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) + self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) + _diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + int(self.xs[idx]), + int(self.ys[idx]), + ) + + @staticmethod + def _filter_metadata( + metadata_uri: str | Path, + tissue_column: str, + tissue_min: float, + ) -> pd.DataFrame: + local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + columns = ["slide_id", "x", "y", tissue_column] + df = pd.read_parquet(local, columns=columns) + if tissue_column not in df.columns: + raise ValueError( + f"metadata parquet has no {tissue_column!r} column; cannot filter" + ) + df = df.loc[df[tissue_column] > tissue_min, ["slide_id", "x", "y"]] + if df.empty: + raise RuntimeError( + f"all tiles dropped by {tissue_column} > {tissue_min} filter" + ) + return df diff --git a/scripts/submit_predict_tissue_tiles.py b/scripts/submit_predict_tissue_tiles.py new file mode 100644 index 00000000..38010592 --- /dev/null +++ 
b/scripts/submit_predict_tissue_tiles.py @@ -0,0 +1,29 @@ +from kube_jobs import storage, submit_job + + +# Final probe checkpoint to predict with (same convention as submit_test_linear). +checkpoint = "mlflow-artifacts:/104/0e2230c722134ce0985e09a18ccadf75/artifacts/checkpoints/last/checkpoint.ckpt" + +# MLflow run of embeddings_virchow2_tissue_tiles_05mpp (all test-split tiles +# intersecting the tissue mask). Fill after that preprocessing run completes. +tissue_embedding_run_id = "FILL_ME" + + +# Predicts over every test tile intersecting the tissue mask (no labels, no +# metrics). Loads all tissue-tile embeddings into one in-memory array, so this +# needs more memory than the annotated-only test job. +submit_job( + job_name="tissue-classification-predict-tissue-tiles", + username="vcifka", + cpu=8, + memory="128Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + f'PYTHONUNBUFFERED=1 uv run python -m ml +experiment=ml/linear_classifier_predict_tissue_tiles mode=predict checkpoint=\\"{checkpoint}\\" tissue_embedding_run_id={tissue_embedding_run_id}', + ], + storage=[storage.secure.PROJECTS], +) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index 4af3904c..6d2931af 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -1,21 +1,21 @@ from kube_jobs import storage, submit_job -checkpoint = "mlflow-artifacts:/104/a23e478b00b04da79cfbf4d91cada8cd/artifacts/checkpoints/last/checkpoint.ckpt" +checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" submit_job( - job_name="tissue-classification-test-linear-final-adamw", - username="vcifka", + job_name="tissue-classification-test-linear-final", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git 
workdir", + "git https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - f'PYTHONUNBUFFERED=1 uv run python -m ml +experiment=ml/linear_classifier_test_adamw mode=test checkpoint=\\"{checkpoint}\\"', + f'uv run python -m ml +experiment=... mode=test checkpoint=\\"{checkpoint}\\"', ], storage=[storage.secure.PROJECTS], ) From ca50a7c78edc05987a74ed0000e9583f3cd2d409 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 17 May 2026 10:33:07 +0200 Subject: [PATCH 087/107] feat: add embeddings for whole slide --- preprocessing/embeddings.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/preprocessing/embeddings.py b/preprocessing/embeddings.py index de981435..7ad74dda 100644 --- a/preprocessing/embeddings.py +++ b/preprocessing/embeddings.py @@ -55,7 +55,15 @@ async def __call__(self, row: dict[str, Any]) -> dict[str, Any]: @hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: - for name in ["train", "test"]: + tile_source_run_id = config.get( + "tile_source_run_id", config.dataset.mlflow_artifacts.filter_tiles_run_id + ) + tile_source_artifact_template = config.get( + "tile_source_artifact_template", "filter_tiles/{split}_tiles.parquet" + ) + tile_filter_column = config.get("tile_filter_column") + + for name in config.get("splits", ["train", "test"]): split_folder = Path( mlflow.artifacts.download_artifacts( run_id=config.dataset.mlflow_artifacts.tiling_run_id, @@ -69,19 +77,35 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: tiles_path = Path( mlflow.artifacts.download_artifacts( - run_id=config.dataset.mlflow_artifacts.filter_tiles_run_id, - artifact_path=f"filter_tiles/{name}_tiles.parquet", + run_id=tile_source_run_id, + artifact_path=tile_source_artifact_template.format(split=name), ) ) - num_rows = 
pads.dataset(str(tiles_path), format="parquet").count_rows() + row_filter = ( + pads.field(tile_filter_column) > 0 + if tile_filter_column is not None + else None + ) + num_rows = pads.dataset(str(tiles_path), format="parquet").count_rows( + filter=row_filter + ) num_blocks = max(1, num_rows // config.block_size) + columns = ["slide_id", "x", "y"] + if tile_filter_column is not None: + columns.append(tile_filter_column) ds = ray.data.read_parquet( str(tiles_path), - columns=["slide_id", "x", "y"], + columns=columns, ray_remote_args={"memory": 8 * 1024**3}, override_num_blocks=num_blocks, - ).map( + ) + if tile_filter_column is not None: + ds = ds.filter( + lambda row, c: row[c] > 0, fn_kwargs={"c": tile_filter_column} + ) + ds = ds.drop_columns([tile_filter_column]) + ds = ds.map( lambda row, si: {**row, **si[row["slide_id"]]}, fn_kwargs={"si": slide_info}, ) From 6489cd02aa353b28879306853a836da436a99417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 17 May 2026 10:51:46 +0200 Subject: [PATCH 088/107] refactor: compute grayscale mask per each class --- ml/callbacks/tiff_prediction_map_writer.py | 78 +++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 6e302a61..f8c6d14e 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -126,6 +126,7 @@ def _write_maps(self, trainer: pl.Trainer) -> None: Path(output_path, "pred").mkdir(parents=True, exist_ok=True) if self.write_errors: Path(output_path, "errors").mkdir(parents=True, exist_ok=True) + Path(output_path, "prob").mkdir(parents=True, exist_ok=True) slide_groups = self._select_slide_groups(predictions) print( @@ -145,7 +146,12 @@ def _write_maps(self, trainer: pl.Trainer) -> None: f"[TiffPredictionMapWriter] {index}/{len(slide_groups)} " f"{Path(str(slide['path'])).name}" ) - 
self._write_slide_maps(slide, slide_predictions, output_path) + self._write_slide_maps( + slide, + slide_predictions, + output_path, + class_names=getattr(trainer.lightning_module, "class_names", None), + ) active = mlflow.active_run() if active is not None: @@ -191,6 +197,7 @@ def _write_slide_maps( slide: dict[str, Any], predictions: pd.DataFrame, output_path: Path, + class_names: list[str] | None, ) -> None: filename = f"{_safe_filename(Path(str(slide['path'])).stem)}.tiff" extent = (int(slide["extent_x"]), int(slide["extent_y"])) @@ -215,6 +222,24 @@ def _write_slide_maps( mpp=mpp, ) + names = ( + class_names + if class_names is not None and len(class_names) == probs.shape[1] + else [f"class_{i}" for i in range(probs.shape[1])] + ) + _write_per_class_probability_maps( + probs=probs, + class_names=names, + xs=xs, + ys=ys, + extent=extent, + tile_extent=tile_extent, + stride=stride, + output_dir=Path(output_path, "prob"), + filename=filename, + mpp=mpp, + ) + if not self.write_errors: return @@ -408,6 +433,57 @@ def _write_error_map( ) +def _write_per_class_probability_maps( + probs: np.ndarray, + class_names: list[str], + xs: np.ndarray, + ys: np.ndarray, + extent: tuple[int, int], + tile_extent: tuple[int, int], + stride: tuple[int, int], + output_dir: Path, + filename: str, + mpp: tuple[float, float], +) -> None: + """Write one grayscale probability map per class. + + Each class channel is assembled independently with HeatmapAssembler, so + overlapping tiles are averaged exactly like scalar heatmaps. Values are + encoded as uint8 probabilities in [0, 255], with never-covered pixels set + to 0. 
+ """ + from rationai.masks.heatmap_assembler import HeatmapAssembler + + xs_t = torch.from_numpy(xs) + ys_t = torch.from_numpy(ys) + for class_idx, class_name in enumerate(class_names): + assembler = HeatmapAssembler( + extent[0], + extent[1], + tile_extent[0], + tile_extent[1], + stride[0], + stride[1], + dtype=torch.float32, + ) + assembler.update( + torch.from_numpy(probs[:, class_idx].astype(np.float32, copy=False)), + xs_t, + ys_t, + ) + grid = np.clip(assembler.compute().numpy() * 255.0, 0, 255).astype(np.uint8) + grid[assembler._count.numpy() == 0] = 0 + _emit_mask( + np.ascontiguousarray(grid), + assembler.common_divisor_x, + assembler.common_divisor_y, + extent, + 0, + output_dir / _safe_filename(class_name) / filename, + mpp, + ) + + def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: return { key: value.detach().cpu() if isinstance(value, torch.Tensor) else value From 9d8729a337224e0f687802454c66f31d88524832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 11:26:25 +0200 Subject: [PATCH 089/107] refactor: do not generate error masks --- ml/callbacks/tiff_prediction_map_writer.py | 68 ---------------------- 1 file changed, 68 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index f8c6d14e..a19d641b 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -29,7 +29,6 @@ def __init__( artifact_path: str = "prediction_maps_tiff", background_value: int = 255, draw_region: str = "central_stride", - write_errors: bool = True, max_slides: int | None = None, slide_selection: str = "all", ) -> None: @@ -51,7 +50,6 @@ def __init__( self.artifact_path = artifact_path self.background_value = background_value self.draw_region = draw_region - self.write_errors = write_errors self.max_slides = max_slides self.slide_selection = slide_selection self._batches: list[dict[str, Any]] = 
[] @@ -124,8 +122,6 @@ def _write_maps(self, trainer: pl.Trainer) -> None: with TemporaryDirectory(dir=Path(trainer.default_root_dir)) as output_dir: output_path = Path(output_dir) Path(output_path, "pred").mkdir(parents=True, exist_ok=True) - if self.write_errors: - Path(output_path, "errors").mkdir(parents=True, exist_ok=True) Path(output_path, "prob").mkdir(parents=True, exist_ok=True) slide_groups = self._select_slide_groups(predictions) @@ -240,23 +236,6 @@ def _write_slide_maps( mpp=mpp, ) - if not self.write_errors: - return - - errors = ( - predictions["pred"].to_numpy() != predictions["target"].to_numpy() - ).astype(np.float32) - _write_error_map( - errors=errors, - xs=xs, - ys=ys, - extent=extent, - tile_extent=tile_extent, - stride=stride, - path=Path(output_path, "errors", filename), - mpp=mpp, - ) - class _ClassVoteAssembler: """Confidence-weighted per-class accumulator over the full tile footprint. @@ -386,53 +365,6 @@ def _write_class_map( ) -def _write_error_map( - errors: np.ndarray, - xs: np.ndarray, - ys: np.ndarray, - extent: tuple[int, int], - tile_extent: tuple[int, int], - stride: tuple[int, int], - path: Path, - mpp: tuple[float, float], -) -> None: - """Per-tile error (pred != target) averaged over the full tile footprint. - - Averaging a 0/1 error fraction across overlapping tiles is meaningful - (fraction of covering tiles that were wrong); thresholded back to a 2-class - map encoded in the same spread space as the GT (correct -> low value, - wrong -> 255), background -> 0. 
- """ - from rationai.masks.heatmap_assembler import HeatmapAssembler - - assembler = HeatmapAssembler( - extent[0], - extent[1], - tile_extent[0], - tile_extent[1], - stride[0], - stride[1], - dtype=torch.float32, - ) - assembler.update( - torch.from_numpy(errors), - torch.from_numpy(xs), - torch.from_numpy(ys), - ) - wrong = (assembler.compute() >= 0.5).numpy().astype(np.intp) - grid = _spread_lut(2)[wrong] # correct -> 128, wrong -> 255 - grid[assembler._count.numpy() == 0] = 0 - _emit_mask( - np.ascontiguousarray(grid), - assembler.common_divisor_x, - assembler.common_divisor_y, - extent, - 0, - path, - mpp, - ) - - def _write_per_class_probability_maps( probs: np.ndarray, class_names: list[str], From ac0ce1646d740a964c341752d26b9269c17db2d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 11:56:57 +0200 Subject: [PATCH 090/107] chore: config cleanup --- .../ml/linear_classifier_final_adamw.yaml | 4 +-- .../ml/linear_classifier_final_lbfgs.yaml | 2 +- ...inear_classifier_predict_tissue_tiles.yaml | 21 ++++++++++---- ...ear_classifier_stratified_group_kfold.yaml | 7 ++++- .../linear_classifier_stratified_kfold.yaml | 7 ++++- .../ml/linear_classifier_test_adamw.yaml | 6 ++-- .../ml/linear_classifier_test_lbfgs.yaml | 4 ++- ..._final.yaml => final_embedding_tiles.yaml} | 0 ...edding.yaml => kfold_embedding_tiles.yaml} | 0 configs/ml/model/linear_classifier.yaml | 16 +++------- .../final_linear_classifier.yaml} | 4 +-- .../kfold_linear_classifier.yaml} | 4 +-- .../{default.yaml => early_stopping.yaml} | 0 ...l.yaml => final_with_prediction_maps.yaml} | 1 - scripts/submit_predict_tissue_tiles.py | 29 ------------------- scripts/submit_test_linear.py | 7 ++--- 16 files changed, 48 insertions(+), 64 deletions(-) rename configs/ml/data/{embedding_final.yaml => final_embedding_tiles.yaml} (100%) rename configs/ml/data/{embedding.yaml => kfold_embedding_tiles.yaml} (100%) rename 
configs/ml/{linear_classifier_final.yaml => task/final_linear_classifier.yaml} (94%) rename configs/ml/{linear_classifier.yaml => task/kfold_linear_classifier.yaml} (96%) rename configs/ml/trainer/{default.yaml => early_stopping.yaml} (100%) rename configs/ml/trainer/{final.yaml => final_with_prediction_maps.yaml} (97%) delete mode 100644 scripts/submit_predict_tissue_tiles.py diff --git a/configs/experiment/ml/linear_classifier_final_adamw.yaml b/configs/experiment/ml/linear_classifier_final_adamw.yaml index 2d0feba8..9cc037f9 100644 --- a/configs/experiment/ml/linear_classifier_final_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_final_adamw.yaml @@ -1,8 +1,8 @@ # @package _global_ defaults: - - /ml/linear_classifier_final - - override /ml/trainer: default + - /ml/task: final_linear_classifier + - override /ml/trainer: early_stopping - _self_ # AdamW final: trained to convergence with the same early-stopping rule as the diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml index f83f3626..67174b0d 100644 --- a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - /ml/linear_classifier_final + - /ml/task: final_linear_classifier - _self_ # LBFGS final: exact solve of the convex objective on the full training batch. diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml index 7d0d6623..400be4bf 100644 --- a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -1,13 +1,15 @@ # @package _global_ defaults: - - /ml/linear_classifier_final + - /ml/task: final_linear_classifier - _self_ # Predict over every test-split tile that intersects the tissue mask. 
This run # is unlabeled: it writes predictions/maps for review and does not compute test # metrics. mode: predict +final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt tissue_embedding_run_id: ??? tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} @@ -27,10 +29,19 @@ data: tissue_column: ${tissue_column} tissue_min: ${tissue_min} -trainer: - callbacks: - tiff_prediction_maps: - write_errors: false +model: + optimizer: lbfgs + learning_rate: 1.0 + weight_decay: 1.0e-2 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false metadata: run_name: Predict Linear Classifier tissue tiles ${dataset.name} diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml index 471f5a36..86f56479 100644 --- a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml @@ -1,8 +1,13 @@ # @package _global_ defaults: - - /ml/linear_classifier + - /ml/task: kfold_linear_classifier - _self_ kfold_strategy: stratified_group kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} + +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 0.0 diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml index c01fbbf9..5d09c027 100644 --- a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml @@ -1,8 +1,13 @@ # @package _global_ defaults: - - /ml/linear_classifier + - /ml/task: kfold_linear_classifier - _self_ kfold_strategy: stratified kfold_run_id: 
${dataset.mlflow_artifacts.stratified_kfold_run_id} + +model: + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 0.0 diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml index c3400d79..d39cb97b 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -6,14 +6,16 @@ defaults: # Test the AdamW final checkpoint on the held-out test split. Same model # architecture as the final run (required for state_dict load); optimizer -# fields are inert at test. Checkpoint is passed on the CLI. +# fields are inert at test. # mode: test +final_train_run_id: a23e478b00b04da79cfbf4d91cada8cd +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt # num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into # one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing # (no per-item IO), so workers give zero speedup; num_workers>0 forks the # parent (pyarrow/mlflow/fsspec thread state + large array) and deadlocks -# before the first test batch. embedding_final defaults to 4; override here. +# before the first test batch. final_embedding_tiles defaults to 4; override here. data: num_workers: 0 diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index d455774d..0001eadc 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -2,7 +2,7 @@ defaults: - /experiment/ml/linear_classifier_final_lbfgs - - override /ml/trainer: default + - override /ml/trainer: early_stopping - _self_ # Test the LBFGS final checkpoint on the held-out test split. The full-batch @@ -17,6 +17,8 @@ defaults: # which deadlocks before the first test batch. # Same model architecture as the final run (required for state_dict load). 
mode: test +final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 +checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt data: batch_size: 1024 diff --git a/configs/ml/data/embedding_final.yaml b/configs/ml/data/final_embedding_tiles.yaml similarity index 100% rename from configs/ml/data/embedding_final.yaml rename to configs/ml/data/final_embedding_tiles.yaml diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/kfold_embedding_tiles.yaml similarity index 100% rename from configs/ml/data/embedding.yaml rename to configs/ml/data/kfold_embedding_tiles.yaml diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index 4b4d9e83..2760ab73 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,15 +11,7 @@ model: class_indices: ${class_indices} - optimizer: adamw - learning_rate: 1.0e-4 - weight_decay: 0.0 - lbfgs: - max_iter: 100 - max_eval: null - tolerance_grad: 1.0e-7 - tolerance_change: 1.0e-9 - history_size: 100 - line_search_fn: strong_wolfe - accumulate_batches: 1 - accumulate_on_cpu: false + optimizer: ??? + learning_rate: ??? + weight_decay: ??? 
+ lbfgs: null diff --git a/configs/ml/linear_classifier_final.yaml b/configs/ml/task/final_linear_classifier.yaml similarity index 94% rename from configs/ml/linear_classifier_final.yaml rename to configs/ml/task/final_linear_classifier.yaml index d49c2c79..154cec5e 100644 --- a/configs/ml/linear_classifier_final.yaml +++ b/configs/ml/task/final_linear_classifier.yaml @@ -3,8 +3,8 @@ defaults: - /data: dataset - /class_mapping: collapse_alterations_to_other - - /ml/trainer: final - - /ml/data: embedding_final + - /ml/trainer: final_with_prediction_maps + - /ml/data: final_embedding_tiles - /ml/model: linear_classifier - _self_ diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/task/kfold_linear_classifier.yaml similarity index 96% rename from configs/ml/linear_classifier.yaml rename to configs/ml/task/kfold_linear_classifier.yaml index d3393372..859ce685 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/task/kfold_linear_classifier.yaml @@ -3,8 +3,8 @@ defaults: - /data: dataset - /class_mapping: collapse_alterations_to_other - - /ml/trainer: default - - /ml/data: embedding + - /ml/trainer: early_stopping + - /ml/data: kfold_embedding_tiles - /ml/model: linear_classifier - _self_ diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/early_stopping.yaml similarity index 100% rename from configs/ml/trainer/default.yaml rename to configs/ml/trainer/early_stopping.yaml diff --git a/configs/ml/trainer/final.yaml b/configs/ml/trainer/final_with_prediction_maps.yaml similarity index 97% rename from configs/ml/trainer/final.yaml rename to configs/ml/trainer/final_with_prediction_maps.yaml index 725803eb..c744d281 100644 --- a/configs/ml/trainer/final.yaml +++ b/configs/ml/trainer/final_with_prediction_maps.yaml @@ -28,4 +28,3 @@ trainer: draw_region: central_stride slide_selection: all max_slides: null - write_errors: true diff --git a/scripts/submit_predict_tissue_tiles.py b/scripts/submit_predict_tissue_tiles.py deleted file mode 100644 
index 38010592..00000000 --- a/scripts/submit_predict_tissue_tiles.py +++ /dev/null @@ -1,29 +0,0 @@ -from kube_jobs import storage, submit_job - - -# Final probe checkpoint to predict with (same convention as submit_test_linear). -checkpoint = "mlflow-artifacts:/104/0e2230c722134ce0985e09a18ccadf75/artifacts/checkpoints/last/checkpoint.ckpt" - -# MLflow run of embeddings_virchow2_tissue_tiles_05mpp (all test-split tiles -# intersecting the tissue mask). Fill after that preprocessing run completes. -tissue_embedding_run_id = "FILL_ME" - - -# Predicts over every test tile intersecting the tissue mask (no labels, no -# metrics). Loads all tissue-tile embeddings into one in-memory array, so this -# needs more memory than the annotated-only test job. -submit_job( - job_name="tissue-classification-predict-tissue-tiles", - username="vcifka", - cpu=8, - memory="128Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - f'PYTHONUNBUFFERED=1 uv run python -m ml +experiment=ml/linear_classifier_predict_tissue_tiles mode=predict checkpoint=\\"{checkpoint}\\" tissue_embedding_run_id={tissue_embedding_run_id}', - ], - storage=[storage.secure.PROJECTS], -) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index 6d2931af..3eccdea7 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -1,9 +1,6 @@ from kube_jobs import storage, submit_job -checkpoint = "mlflow-artifacts:/104//artifacts/checkpoints/last/checkpoint.ckpt" - - submit_job( job_name="tissue-classification-test-linear-final", username=..., @@ -12,10 +9,10 @@ gpu=None, public=False, script=[ - "git https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - f'uv run python -m ml +experiment=... 
mode=test checkpoint=\\"{checkpoint}\\"', + "uv run python -m ml +experiment=...", ], storage=[storage.secure.PROJECTS], ) From 099e277027b23de601070b9de84b4ec97d389f2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 11:57:54 +0200 Subject: [PATCH 091/107] feat: add prints to the prediction maps writer --- ml/callbacks/tiff_prediction_map_writer.py | 31 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index a19d641b..1a5924e9 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -127,7 +127,8 @@ def _write_maps(self, trainer: pl.Trainer) -> None: slide_groups = self._select_slide_groups(predictions) print( f"[TiffPredictionMapWriter] writing {len(slide_groups)} " - f"prediction map(s)" + f"prediction map(s)", + flush=True, ) for index, (slide_id, slide_predictions) in enumerate( slide_groups, start=1 @@ -140,7 +141,8 @@ def _write_maps(self, trainer: pl.Trainer) -> None: ) print( f"[TiffPredictionMapWriter] {index}/{len(slide_groups)} " - f"{Path(str(slide['path'])).name}" + f"{Path(str(slide['path'])).name}", + flush=True, ) self._write_slide_maps( slide, @@ -153,7 +155,8 @@ def _write_maps(self, trainer: pl.Trainer) -> None: if active is not None: print( f"[TiffPredictionMapWriter] logging artifacts to " - f"{self.artifact_path}" + f"{self.artifact_path}", + flush=True, ) mlflow.log_artifacts(output_dir, artifact_path=self.artifact_path) @@ -207,6 +210,11 @@ def _write_slide_maps( [np.asarray(prob, dtype=np.float32) for prob in predictions["probs"]] ) + pred_path = Path(output_path, "pred", filename) + print( + f"[TiffPredictionMapWriter] writing pred/{filename}", + flush=True, + ) _write_class_map( probs=probs, xs=xs, @@ -214,9 +222,13 @@ def _write_slide_maps( extent=extent, tile_extent=tile_extent, stride=stride, - 
path=Path(output_path, "pred", filename), + path=pred_path, mpp=mpp, ) + print( + f"[TiffPredictionMapWriter] wrote pred/{filename}", + flush=True, + ) names = ( class_names @@ -389,6 +401,11 @@ def _write_per_class_probability_maps( xs_t = torch.from_numpy(xs) ys_t = torch.from_numpy(ys) for class_idx, class_name in enumerate(class_names): + class_dir = _safe_filename(class_name) + print( + f"[TiffPredictionMapWriter] writing prob/{class_dir}/{filename}", + flush=True, + ) assembler = HeatmapAssembler( extent[0], extent[1], @@ -411,9 +428,13 @@ def _write_per_class_probability_maps( assembler.common_divisor_y, extent, 0, - output_dir / _safe_filename(class_name) / filename, + output_dir / class_dir / filename, mpp, ) + print( + f"[TiffPredictionMapWriter] wrote prob/{class_dir}/{filename}", + flush=True, + ) def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: From 490932469431aec97fd1a8eb0ebc61d90ee2ec42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 11:59:30 +0200 Subject: [PATCH 092/107] feat: add embeddings run id for the whole tissue tiles run --- .../experiment/ml/linear_classifier_predict_tissue_tiles.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml index 400be4bf..56845fa3 100644 --- a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -11,7 +11,7 @@ mode: predict final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt -tissue_embedding_run_id: ??? 
+tissue_embedding_run_id: 95a02c93c164415e94702ad5c83ccca2 tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} tissue_stats_artifact_path: tissue_stats tissue_column: tile_tissue_coverage From ee9d2dacb13ffada4a5fcde55e33abda382469a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 12:07:48 +0200 Subject: [PATCH 093/107] feat: add prediction maps in configs --- .../experiment/ml/linear_classifier_test_adamw.yaml | 10 ++++++++++ .../experiment/ml/linear_classifier_test_lbfgs.yaml | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml index d39cb97b..6c775cfc 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -19,3 +19,13 @@ checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/la # before the first test batch. final_embedding_tiles defaults to 4; override here. 
data: num_workers: 0 + +trainer: + callbacks: + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index 0001eadc..f7c5ae41 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -23,3 +23,13 @@ checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/la data: batch_size: 1024 num_workers: 0 + +trainer: + callbacks: + tiff_prediction_maps: + _target_: ml.callbacks.TiffPredictionMapWriter + slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet + artifact_path: prediction_maps_tiff + draw_region: central_stride + slide_selection: all + max_slides: null From 8b3a82d772e9a725981569340f90658c7f08c3b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 12:55:28 +0200 Subject: [PATCH 094/107] chore: deduplicate, apply safety nets --- ...inear_classifier_predict_tissue_tiles.yaml | 3 + ml/__main__.py | 1 - ml/callbacks/tiff_prediction_map_writer.py | 46 +++-- ml/data/data_module.py | 6 +- ml/data/datasets/embedding_tiles.py | 178 +++++++----------- 5 files changed, 109 insertions(+), 125 deletions(-) diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml index 56845fa3..734aa8dd 100644 --- a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -21,6 +21,9 @@ test_embedding_uri: runs:/${tissue_embedding_run_id}/test/tiles test_metadata_uri: 
runs:/${tissue_stats_run_id}/${tissue_stats_artifact_path}/test_tiles.parquet data: + # num_workers MUST be 0 for this config. UnlabeledEmbeddingTilesDataset and + # EmbeddingTilesDataset load the entire split into memory in __init__; using + # worker processes (num_workers > 0) can deadlock after multiprocessing/fork. num_workers: 0 predict: _target_: ml.data.datasets.UnlabeledEmbeddingTilesDataset diff --git a/ml/__main__.py b/ml/__main__.py index 7cf951cd..41be405f 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -33,7 +33,6 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: model, datamodule=data, ckpt_path=_resolve_checkpoint(config.checkpoint), - weights_only=False, ) mlflow.end_run() diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 1a5924e9..29beebc8 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -72,9 +72,9 @@ def on_test_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: - if trainer.global_rank == 0 and isinstance(outputs, Mapping): + if isinstance(outputs, Mapping): self._batches.append(_to_cpu_batch(outputs)) - if batch_idx % 50 == 0: + if trainer.global_rank == 0 and batch_idx % 50 == 0: print( f"[TiffPredictionMapWriter] test batch {batch_idx} " f"({len(self._batches)} buffered)", @@ -90,7 +90,7 @@ def on_predict_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: - if trainer.global_rank == 0 and outputs is not None: + if outputs is not None: self._batches.append(_to_cpu_batch(outputs)) def on_test_epoch_end( @@ -104,12 +104,14 @@ def on_predict_epoch_end( self._write_maps(trainer) def _write_maps(self, trainer: pl.Trainer) -> None: - if trainer.global_rank != 0 or not self._batches: - self._batches.clear() + batches = _gather_batches(self._batches) + self._batches.clear() + if trainer.global_rank != 0: + return + if not batches: return - predictions = _batches_to_dataframe(self._batches) - 
self._batches.clear() + predictions = _batches_to_dataframe(batches) if predictions.empty: return @@ -224,6 +226,7 @@ def _write_slide_maps( stride=stride, path=pred_path, mpp=mpp, + background_value=self.background_value, ) print( f"[TiffPredictionMapWriter] wrote pred/{filename}", @@ -296,18 +299,19 @@ def update(self, probs: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor) -> Non self._acc[:, y0:y1, x0:x1] += prob[:, None, None] self._count[y0:y1, x0:x1] += 1 - def labels(self) -> np.ndarray: + def labels(self, background_value: int = 0) -> np.ndarray: """Encode prediction labels like ``remap_annotation_masks``. Class ``i`` (0-based) maps to - ``round(255 * (i + 1) / n_classes)``. Never-covered pixels map to ``0``. + ``round(255 * (i + 1) / n_classes)``. Never-covered pixels map to + ``background_value``. The reporting tool expects GT and prediction masks in the same evenly-spread value space, so this must mirror that LUT exactly. """ idx = self._acc.argmax(0).numpy() out = _spread_lut(self._n_classes)[idx].astype(np.uint8) - out[self._count.numpy() == 0] = 0 + out[self._count.numpy() == 0] = _uint8_scalar(background_value) return np.ascontiguousarray(out) @@ -351,6 +355,7 @@ def _write_class_map( stride: tuple[int, int], path: Path, mpp: tuple[float, float], + background_value: int = 0, ) -> None: assembler = _ClassVoteAssembler( extent[0], @@ -367,11 +372,11 @@ def _write_class_map( torch.from_numpy(ys), ) _emit_mask( - assembler.labels(), + assembler.labels(background_value), assembler.cdx, assembler.cdy, extent, - 0, + int(_uint8_scalar(background_value)), path, mpp, ) @@ -437,6 +442,17 @@ def _write_per_class_probability_maps( ) +def _gather_batches(batches: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return batches + + gathered: list[list[dict[str, Any]]] = [ + [] for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather_object(gathered, 
batches) + return [batch for rank_batches in gathered for batch in rank_batches] + + def _to_cpu_batch(batch: Mapping[str, Any]) -> dict[str, Any]: return { key: value.detach().cpu() if isinstance(value, torch.Tensor) else value @@ -479,9 +495,13 @@ def _spread_lut(n_classes: int) -> np.ndarray: Mirrors ``preprocessing/remap_annotation_masks.py`` exactly so prediction masks share the GT value space the reporting tool expects: class ``i`` -> ``round(255 * (i + 1) / n_classes)``. Index 0 of the - returned array is unused by callers (background is set separately to 0). + returned array is unused by callers (background is set separately). """ return np.array( [round(255 * (i + 1) / n_classes) for i in range(n_classes)], dtype=np.uint8, ) + + +def _uint8_scalar(value: int) -> int: + return np.asarray(value).clip(0, np.iinfo(np.uint8).max).astype(np.uint8).item() diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 43566b98..3d671eb4 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -42,7 +42,11 @@ def setup(self, stage: str) -> None: else None ) case "validate": - self.val = instantiate(self.datasets["val"]) + self.val = ( + instantiate(self.datasets["val"]) + if self.datasets.get("val") is not None + else None + ) case "test": self.test = instantiate(self.datasets["test"]) case "predict": diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 6a8a2a0f..2b42ddfa 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,6 +5,7 @@ load time to produce ``(embedding, class_index, slide_id, x, y)`` samples. 
""" +from collections.abc import Callable from functools import cache from pathlib import Path from time import perf_counter @@ -62,63 +63,9 @@ def _diag(msg: str) -> None: ) _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") - emb_dir = self._resolve_uri(embedding_uri) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] + joined_keys, embeddings = _load_embeddings_and_join( + embedding_uri, meta_df, _diag ) - _diag(f"embedding table loaded: {emb_table.num_rows} rows") - - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_col.type): - target_type = pa.large_list(emb_col.type.value_type) - emb_col = pa.chunked_array( - [c.cast(target_type) for c in emb_col.chunks], type=target_type - ) - - emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) - emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) - del emb_table - - meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) - del meta_df - joined_keys = meta_table.join( - emb_keys, keys=["slide_id", "x", "y"], join_type="inner" - ) - del emb_keys, meta_table - if joined_keys.num_rows == 0: - raise RuntimeError("inner join with embeddings produced empty dataset") - _diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") - - _idx_col = joined_keys.column("_emb_idx") - if isinstance(_idx_col, pa.ChunkedArray): - _idx_col = _idx_col.combine_chunks() - indices_np = _idx_col.to_numpy() - - first_chunk = emb_col.chunks[0] - embedding_dim = len(first_chunk.values) // len(first_chunk) - - # sort indices for sequential per-chunk access; restore order afterwards - sort_order = np.argsort(indices_np) - sorted_indices = indices_np[sort_order] - - chunk_offsets = np.concatenate( - [[0], np.cumsum([len(c) for c in emb_col.chunks])] - ) - embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) - for ci, chunk in enumerate(emb_col.chunks): - lo, hi = 
chunk_offsets[ci], chunk_offsets[ci + 1] - mask = (sorted_indices >= lo) & (sorted_indices < hi) - if not mask.any(): - continue - local_idx = sorted_indices[mask] - lo - chunk_np = ( - chunk.values.to_numpy(zero_copy_only=False) - .reshape(len(chunk), embedding_dim) - .astype(np.float32) - ) - embeddings[sort_order[mask]] = chunk_np[local_idx] - del emb_col - self.embeddings = embeddings labels = joined_keys.column("label").to_pandas() unknown = set(labels.unique()) - set(class_indices.keys()) @@ -215,6 +162,69 @@ def _resolve_uri_cached(uri: str) -> str: return uri +def _load_embeddings_and_join( + embedding_uri: str | Path, + meta_df: pd.DataFrame, + diag: Callable[[str], None], +) -> tuple[pa.Table, np.ndarray]: + emb_dir = EmbeddingTilesDataset._resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + diag(f"embedding table loaded: {emb_table.num_rows} rows") + + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table + + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") + + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() + + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // 
len(first_chunk) + + # Sort indices for sequential per-chunk access; restore order afterwards. + sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate([[0], np.cumsum([len(c) for c in emb_col.chunks])]) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & (sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col + + return joined_keys, embeddings + + class UnlabeledEmbeddingTilesDataset(Dataset[Sample]): """Tile-embedding dataset for prediction over tiles without class labels.""" @@ -238,61 +248,9 @@ def _diag(msg: str) -> None: meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min) _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") - emb_dir = EmbeddingTilesDataset._resolve_uri(embedding_uri) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] + joined_keys, embeddings = _load_embeddings_and_join( + embedding_uri, meta_df, _diag ) - _diag(f"embedding table loaded: {emb_table.num_rows} rows") - - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_col.type): - target_type = pa.large_list(emb_col.type.value_type) - emb_col = pa.chunked_array( - [c.cast(target_type) for c in emb_col.chunks], type=target_type - ) - - emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) - emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) - del emb_table - - meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) - del meta_df - joined_keys = meta_table.join( - emb_keys, keys=["slide_id", "x", "y"], 
join_type="inner" - ) - del emb_keys, meta_table - if joined_keys.num_rows == 0: - raise RuntimeError("inner join with embeddings produced empty dataset") - _diag(f"join done: {joined_keys.num_rows} matched rows; filling embeddings") - - _idx_col = joined_keys.column("_emb_idx") - if isinstance(_idx_col, pa.ChunkedArray): - _idx_col = _idx_col.combine_chunks() - indices_np = _idx_col.to_numpy() - - first_chunk = emb_col.chunks[0] - embedding_dim = len(first_chunk.values) // len(first_chunk) - sort_order = np.argsort(indices_np) - sorted_indices = indices_np[sort_order] - - chunk_offsets = np.concatenate( - [[0], np.cumsum([len(c) for c in emb_col.chunks])] - ) - embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) - for ci, chunk in enumerate(emb_col.chunks): - lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] - mask = (sorted_indices >= lo) & (sorted_indices < hi) - if not mask.any(): - continue - local_idx = sorted_indices[mask] - lo - chunk_np = ( - chunk.values.to_numpy(zero_copy_only=False) - .reshape(len(chunk), embedding_dim) - .astype(np.float32) - ) - embeddings[sort_order[mask]] = chunk_np[local_idx] - del emb_col - self.embeddings = embeddings self.labels = np.full(len(joined_keys), label_value, dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() From e16426eb93b3787d3fa7506197632d90b95c32d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 18:06:52 +0200 Subject: [PATCH 095/107] fix: pytorch checkpoint loading --- .../ml/linear_classifier_predict_tissue_tiles.yaml | 1 + .../experiment/ml/linear_classifier_test_adamw.yaml | 1 + .../experiment/ml/linear_classifier_test_lbfgs.yaml | 1 + configs/ml.yaml | 1 + ml/__main__.py | 12 +++++++----- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml 
index 734aa8dd..a8d4ad9c 100644 --- a/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml +++ b/configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml @@ -10,6 +10,7 @@ defaults: mode: predict final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false tissue_embedding_run_id: 95a02c93c164415e94702ad5c83ccca2 tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id} diff --git a/configs/experiment/ml/linear_classifier_test_adamw.yaml b/configs/experiment/ml/linear_classifier_test_adamw.yaml index 6c775cfc..ed214e53 100644 --- a/configs/experiment/ml/linear_classifier_test_adamw.yaml +++ b/configs/experiment/ml/linear_classifier_test_adamw.yaml @@ -11,6 +11,7 @@ defaults: mode: test final_train_run_id: a23e478b00b04da79cfbf4d91cada8cd checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false # num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into # one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index f7c5ae41..2dabef8a 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -19,6 +19,7 @@ defaults: mode: test final_train_run_id: 0e2230c722134ce0985e09a18ccadf75 checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt +checkpoint_weights_only: false data: batch_size: 1024 diff --git a/configs/ml.yaml b/configs/ml.yaml index bfde61e2..c34ff6df 100644 --- a/configs/ml.yaml +++ b/configs/ml.yaml @@ -6,6 +6,7 @@ defaults: seed: ${random_seed:} mode: ??? 
checkpoint: null +checkpoint_weights_only: null trainer: {} diff --git a/ml/__main__.py b/ml/__main__.py index 41be405f..b474b9ee 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -29,11 +29,13 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: allowed_modes = ["fit", "test", "validate", "predict"] if config.mode not in allowed_modes: raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}") - getattr(trainer, config.mode)( - model, - datamodule=data, - ckpt_path=_resolve_checkpoint(config.checkpoint), - ) + run_kwargs = { + "datamodule": data, + "ckpt_path": _resolve_checkpoint(config.checkpoint), + } + if config.checkpoint_weights_only is not None: + run_kwargs["weights_only"] = config.checkpoint_weights_only + getattr(trainer, config.mode)(model, **run_kwargs) mlflow.end_run() From fd3fdd6dc14a2283b3bbe83d3939c4cde37cc4cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 18:32:58 +0200 Subject: [PATCH 096/107] chore: remove redundancy, rename variables --- .../ml/linear_classifier_final_lbfgs.yaml | 4 +- ...assifier_lbfgs_stratified_group_kfold.yaml | 2 +- ...ear_classifier_lbfgs_stratified_kfold.yaml | 2 +- .../ml/linear_classifier_test_lbfgs.yaml | 8 +- configs/ml/data/final_embedding_tiles.yaml | 2 +- configs/ml/data/kfold_embedding_tiles.yaml | 2 +- configs/ml/task/final_linear_classifier.yaml | 2 +- configs/ml/task/kfold_linear_classifier.yaml | 2 +- ml/data/data_module.py | 8 +- ml/data/datasets/embedding_tiles.py | 151 +++++++++--------- ml/meta_arch.py | 14 +- 11 files changed, 94 insertions(+), 103 deletions(-) diff --git a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml index 67174b0d..5169d83c 100644 --- a/configs/experiment/ml/linear_classifier_final_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_final_lbfgs.yaml @@ -6,13 +6,13 @@ defaults: # LBFGS final: exact solve of the 
convex objective on the full training batch. # weight_decay=1e-2 = best LBFGS sweep point. Full-batch guard requires -# train_shuffle=false, train_drop_last=false, batch_size >= len(train); +# train_shuffle=false, train_drop_last=false, train_batch_size >= len(train); # num_workers=0 avoids the single-batch IPC deadlock. trainer: max_epochs: 10 data: - batch_size: 1000000000 + train_batch_size: 1000000000 eval_batch_size: 1024 train_shuffle: false train_drop_last: false diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml index 4d92561f..3f11835d 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -8,7 +8,7 @@ trainer: max_epochs: 10 data: - batch_size: 1000000000 + train_batch_size: 1000000000 train_shuffle: false train_drop_last: false num_workers: 0 diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml index bd3c10b3..f857ccd4 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml @@ -8,7 +8,7 @@ trainer: max_epochs: 10 data: - batch_size: 1000000000 + train_batch_size: 1000000000 train_shuffle: false train_drop_last: false diff --git a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml index 2dabef8a..c005c373 100644 --- a/configs/experiment/ml/linear_classifier_test_lbfgs.yaml +++ b/configs/experiment/ml/linear_classifier_test_lbfgs.yaml @@ -6,9 +6,9 @@ defaults: - _self_ # Test the LBFGS final checkpoint on the held-out test split. 
The full-batch -# batch_size=1e9 is a TRAINING requirement for the convex LBFGS solve only; at -# test there is no optimization, so revert to a normal batch to avoid loading -# the whole test set as one tensor (OOM). +# train_batch_size=1e9 is a TRAINING requirement for the convex LBFGS solve +# only; at test there is no optimization, so use a normal batch to avoid +# loading the whole test set as one tensor (OOM). # # num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into # one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing @@ -22,7 +22,7 @@ checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/la checkpoint_weights_only: false data: - batch_size: 1024 + train_batch_size: 1024 num_workers: 0 trainer: diff --git a/configs/ml/data/final_embedding_tiles.yaml b/configs/ml/data/final_embedding_tiles.yaml index 56355946..826f7a27 100644 --- a/configs/ml/data/final_embedding_tiles.yaml +++ b/configs/ml/data/final_embedding_tiles.yaml @@ -1,7 +1,7 @@ # @package _global_ data: - batch_size: 1024 + train_batch_size: 1024 num_workers: 4 train_shuffle: true train_drop_last: false diff --git a/configs/ml/data/kfold_embedding_tiles.yaml b/configs/ml/data/kfold_embedding_tiles.yaml index 40ff4b71..e1248a80 100644 --- a/configs/ml/data/kfold_embedding_tiles.yaml +++ b/configs/ml/data/kfold_embedding_tiles.yaml @@ -1,7 +1,7 @@ # @package _global_ data: - batch_size: 1024 + train_batch_size: 1024 num_workers: 4 train_shuffle: true train_drop_last: true diff --git a/configs/ml/task/final_linear_classifier.yaml b/configs/ml/task/final_linear_classifier.yaml index 154cec5e..61afef74 100644 --- a/configs/ml/task/final_linear_classifier.yaml +++ b/configs/ml/task/final_linear_classifier.yaml @@ -44,5 +44,5 @@ metadata: thresholds: ${thresholds} learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} - batch_size: ${data.batch_size} + batch_size: ${data.train_batch_size} max_epochs: 
${trainer.max_epochs} diff --git a/configs/ml/task/kfold_linear_classifier.yaml b/configs/ml/task/kfold_linear_classifier.yaml index 859ce685..aa733417 100644 --- a/configs/ml/task/kfold_linear_classifier.yaml +++ b/configs/ml/task/kfold_linear_classifier.yaml @@ -49,6 +49,6 @@ metadata: learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} lbfgs: ${model.lbfgs} - batch_size: ${data.batch_size} + batch_size: ${data.train_batch_size} train_shuffle: ${data.train_shuffle} train_drop_last: ${data.train_drop_last} diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 3d671eb4..b96df1e9 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -17,7 +17,7 @@ class DataModule(LightningDataModule): def __init__( self, - batch_size: int, + train_batch_size: int, eval_batch_size: int | None = None, num_workers: int = 0, train_shuffle: bool = True, @@ -25,8 +25,8 @@ def __init__( **datasets: DictConfig, ) -> None: super().__init__() - self.batch_size = batch_size - self.eval_batch_size = eval_batch_size or batch_size + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size or train_batch_size self.num_workers = num_workers self.train_shuffle = train_shuffle self.train_drop_last = train_drop_last @@ -58,7 +58,7 @@ def setup(self, stage: str) -> None: def train_dataloader(self) -> Iterable[Input]: return DataLoader( self.train, - batch_size=self.batch_size, + batch_size=self.train_batch_size, shuffle=self.train_shuffle, drop_last=self.train_drop_last, num_workers=self.num_workers, diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 2b42ddfa..ba00eb0a 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -21,7 +21,41 @@ from ml.typing import Sample -class EmbeddingTilesDataset(Dataset[Sample]): +class _BaseEmbeddingTilesDataset(Dataset[Sample]): + def __init__( + self, + embedding_uri: str | Path, + meta_df: pd.DataFrame, + diag: 
Callable[[str], None], + ) -> None: + diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") + joined_keys, embeddings = _load_embeddings_and_join( + embedding_uri, meta_df, diag + ) + self.embeddings = embeddings + self.labels = self._labels_from_joined_keys(joined_keys) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) + self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) + diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + int(self.xs[idx]), + int(self.ys[idx]), + ) + + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: + raise NotImplementedError + + +class EmbeddingTilesDataset(_BaseEmbeddingTilesDataset): """Tile-level embedding dataset with on-the-fly filtering and labeling. 
Inner-joins ``embedding`` parquet with ``metadata`` parquet on @@ -45,15 +79,9 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: - t0 = perf_counter() - - def _diag(msg: str) -> None: - print( - f"[EmbeddingTilesDataset +{perf_counter() - t0:6.1f}s] {msg}", - flush=True, - ) - - _diag("filtering metadata") + self.class_indices = class_indices + diag = _make_diag(type(self).__name__) + diag("filtering metadata") meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -61,35 +89,16 @@ def _diag(msg: str) -> None: include_folds, exclude_folds, ) - _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") + super().__init__(embedding_uri, meta_df, diag) - joined_keys, embeddings = _load_embeddings_and_join( - embedding_uri, meta_df, _diag - ) - self.embeddings = embeddings + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: labels = joined_keys.column("label").to_pandas() - unknown = set(labels.unique()) - set(class_indices.keys()) + unknown = set(labels.unique()) - set(self.class_indices.keys()) if unknown: raise ValueError( f"labels in data not present in class_indices: {sorted(unknown)}" ) - self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) - self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() - self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) - self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) - _diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") - - def __len__(self) -> int: - return len(self.labels) - - def __getitem__(self, idx: int) -> Sample: - return ( - torch.from_numpy(self.embeddings[idx]), - int(self.labels[idx]), - str(self.slide_ids[idx]), - int(self.xs[idx]), - int(self.ys[idx]), - ) + return labels.map(self.class_indices).to_numpy(dtype=np.int64) @staticmethod def _filter_metadata( @@ -99,7 +108,7 @@ def _filter_metadata( include_folds: list[int] | 
None, exclude_folds: list[int] | None, ) -> pd.DataFrame: - local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + local = _resolve_uri(metadata_uri) df = pd.read_parquet(local) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] @@ -150,24 +159,13 @@ def _filter_metadata( return df[["slide_id", "x", "y", "label"]] - @staticmethod - def _resolve_uri(path_or_uri: str | Path) -> str: - return EmbeddingTilesDataset._resolve_uri_cached(str(path_or_uri)) - - @staticmethod - @cache - def _resolve_uri_cached(uri: str) -> str: - if uri.startswith(("mlflow-artifacts:/", "runs:/")): - return download_artifacts(artifact_uri=uri) - return uri - def _load_embeddings_and_join( embedding_uri: str | Path, meta_df: pd.DataFrame, diag: Callable[[str], None], ) -> tuple[pa.Table, np.ndarray]: - emb_dir = EmbeddingTilesDataset._resolve_uri(embedding_uri) + emb_dir = _resolve_uri(embedding_uri) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) @@ -225,7 +223,7 @@ def _load_embeddings_and_join( return joined_keys, embeddings -class UnlabeledEmbeddingTilesDataset(Dataset[Sample]): +class UnlabeledEmbeddingTilesDataset(_BaseEmbeddingTilesDataset): """Tile-embedding dataset for prediction over tiles without class labels.""" def __init__( @@ -236,39 +234,14 @@ def __init__( tissue_min: float = 0.0, label_value: int = -1, ) -> None: - t0 = perf_counter() - - def _diag(msg: str) -> None: - print( - f"[UnlabeledEmbeddingTilesDataset +{perf_counter() - t0:6.1f}s] {msg}", - flush=True, - ) - - _diag("filtering metadata") + self.label_value = label_value + diag = _make_diag(type(self).__name__) + diag("filtering metadata") meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min) - _diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings") + super().__init__(embedding_uri, meta_df, diag) - joined_keys, embeddings = _load_embeddings_and_join( - embedding_uri, meta_df, _diag - ) - self.embeddings 
= embeddings - self.labels = np.full(len(joined_keys), label_value, dtype=np.int64) - self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() - self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64) - self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64) - _diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}") - - def __len__(self) -> int: - return len(self.labels) - - def __getitem__(self, idx: int) -> Sample: - return ( - torch.from_numpy(self.embeddings[idx]), - int(self.labels[idx]), - str(self.slide_ids[idx]), - int(self.xs[idx]), - int(self.ys[idx]), - ) + def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray: + return np.full(joined_keys.num_rows, self.label_value, dtype=np.int64) @staticmethod def _filter_metadata( @@ -276,7 +249,7 @@ def _filter_metadata( tissue_column: str, tissue_min: float, ) -> pd.DataFrame: - local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + local = _resolve_uri(metadata_uri) columns = ["slide_id", "x", "y", tissue_column] df = pd.read_parquet(local, columns=columns) if tissue_column not in df.columns: @@ -289,3 +262,23 @@ def _filter_metadata( f"all tiles dropped by {tissue_column} > {tissue_min} filter" ) return df + + +def _resolve_uri(path_or_uri: str | Path) -> str: + return _resolve_uri_cached(str(path_or_uri)) + + +def _make_diag(dataset_name: str) -> Callable[[str], None]: + t0 = perf_counter() + + def _diag(msg: str) -> None: + print(f"[{dataset_name} +{perf_counter() - t0:6.1f}s] {msg}", flush=True) + + return _diag + + +@cache +def _resolve_uri_cached(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri diff --git a/ml/meta_arch.py b/ml/meta_arch.py index a20cf4ce..2146d9b8 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -56,7 +56,7 @@ def __init__( n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) ] num_classes = 
len(self.class_names) - self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) + self.criterion = nn.CrossEntropyLoss() macro_metrics = MetricCollection( { @@ -102,9 +102,7 @@ def setup(self, stage: str) -> None: for cls, w in zip(self.class_names, weights.tolist(), strict=True): mlflow.log_metric(f"class_weight/{cls}", w) else: - self.criterion = nn.CrossEntropyLoss( - weight=torch.ones(len(self.class_names), dtype=torch.float32) - ) + self.criterion = nn.CrossEntropyLoss() def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) @@ -274,9 +272,9 @@ def _prepare_lbfgs_batch( def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: lbfgs = self.hparams.get("lbfgs") or {} - batch_size = int(datamodule.batch_size) + train_batch_size = int(datamodule.train_batch_size) accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) - effective_batch_size = batch_size * accumulation_steps + effective_batch_size = train_batch_size * accumulation_steps if datamodule.train_shuffle: raise ValueError("LBFGS requires data.train_shuffle=false.") @@ -285,9 +283,9 @@ def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: if effective_batch_size < train_size: raise ValueError( "LBFGS requires a deterministic full-batch objective. Set " - "data.batch_size >= len(train) or set " + "data.train_batch_size >= len(train) or set " "model.lbfgs.accumulate_batches >= ceil(len(train) / " - "data.batch_size). Current effective batch size is " + "data.train_batch_size). Current effective batch size is " f"{effective_batch_size} for {train_size} training samples." 
) From c401015325b7f7609876230a06b9d18ae3255d35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 18:33:49 +0200 Subject: [PATCH 097/107] chore: remove username and branch --- scripts/submit_train_linear_probe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 3f7ecb1a..a5669c80 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", + "uv run python -m ml +experiment=m... 
val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], ) From 847c3ccdce916155cca363c31401427acaab6c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 18:35:52 +0200 Subject: [PATCH 098/107] refactor: rename configs --- ...yaml => linear_classifier_adamw_stratified_group_kfold.yaml} | 0 ...kfold.yaml => linear_classifier_adamw_stratified_kfold.yaml} | 0 scripts/submit_train_linear_probe.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename configs/experiment/ml/{linear_classifier_stratified_group_kfold.yaml => linear_classifier_adamw_stratified_group_kfold.yaml} (100%) rename configs/experiment/ml/{linear_classifier_stratified_kfold.yaml => linear_classifier_adamw_stratified_kfold.yaml} (100%) diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml rename to configs/experiment/ml/linear_classifier_adamw_stratified_group_kfold.yaml diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml similarity index 100% rename from configs/experiment/ml/linear_classifier_stratified_kfold.yaml rename to configs/experiment/ml/linear_classifier_adamw_stratified_kfold.yaml diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index a5669c80..dcfc7d93 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -12,7 +12,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=m... val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", + "uv run python -m ml +experiment=... 
val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], ) From 2ba05625eb52f94ef9676823b631d94fabf8b9ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 20:12:59 +0200 Subject: [PATCH 099/107] fix: keep criterion.weight in state_dict for strict checkpoint load setup(stage="fit") replaces criterion with class-weighted CrossEntropyLoss, adding a criterion.weight buffer that gets saved to checkpoints. At test, Lightning restores the checkpoint before setup() runs, so the model still has the unweighted criterion from __init__ and strict load fails with "Unexpected key(s) in state_dict: criterion.weight". Affected both adamw and lbfgs test runs. Initialize criterion with a placeholder ones-weight sized num_classes so the criterion.weight key always exists; setup(fit) still overrides it with the real class-balanced weights. Co-Authored-By: Claude Opus 4.7 --- ml/meta_arch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 2146d9b8..e5159007 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -56,7 +56,10 @@ def __init__( n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) ] num_classes = len(self.class_names) - self.criterion = nn.CrossEntropyLoss() + # Placeholder weight so `criterion.weight` always exists in state_dict. + # setup(stage="fit") overrides with class-balanced weights; checkpoints + # then load with strict=True regardless of stage. 
+ self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) macro_metrics = MetricCollection( { From e3704173eb480e5d60788e79f9fa68c6116610bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 21:15:43 +0200 Subject: [PATCH 100/107] fix: criterion weight --- ml/meta_arch.py | 4 ++-- scripts/submit_test_linear.py | 8 ++++---- submit_report.py | 21 +++++++++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 submit_report.py diff --git a/ml/meta_arch.py b/ml/meta_arch.py index e5159007..c76bb151 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -104,8 +104,8 @@ def setup(self, stage: str) -> None: ) for cls, w in zip(self.class_names, weights.tolist(), strict=True): mlflow.log_metric(f"class_weight/{cls}", w) - else: - self.criterion = nn.CrossEntropyLoss() + # Non-fit stages keep the placeholder ones-weight criterion from + # __init__ so `criterion.weight` stays in state_dict for strict load. 
def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index 3eccdea7..dec0968b 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -2,17 +2,17 @@ submit_job( - job_name="tissue-classification-test-linear-final", - username=..., + job_name="tissue-classification-test-linear-final-adamw", + username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=...", + "uv run python -m ml +experiment=ml/linear_classifier_test_adamw", ], storage=[storage.secure.PROJECTS], ) diff --git a/submit_report.py b/submit_report.py new file mode 100644 index 00000000..e543fff5 --- /dev/null +++ b/submit_report.py @@ -0,0 +1,21 @@ +from kube_jobs import storage, submit_job + +config_dir = "/mnt/projects/tissue_classification/conf/reporter" +config_name = "tissue_classification_lbfgs_mug" + +submit_job( + job_name=f"tissue-classification-report-{config_name.replace('_', '-')}", + username="vcifka", + cpu=8, + memory="16Gi", + gpu=None, + public=False, + script=[ + "git clone https://gitlab.ics.muni.cz/rationai/digital-pathology/pathology/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv pip install git+ssh://git@gitlab.ics.muni.cz/rationai/digital-pathology/pipeline/report.git@feature/force-wsi-service-protocol", + f"uv run python -m report --config-dir {config_dir} --config-name={config_name} user=vcifka" + ], + storage=[storage.secure.DATA, storage.secure.PROJECTS], +) \ No newline at end of file From 632a8f6fab146606a3ac74263a1ed47d77ca5428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 21:56:01 +0200 Subject: [PATCH 101/107] fix: 
keep space in MUG prediction masks names --- ml/callbacks/tiff_prediction_map_writer.py | 2 +- submit_report.py | 21 --------------------- 2 files changed, 1 insertion(+), 22 deletions(-) delete mode 100644 submit_report.py diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index 29beebc8..c9e82ad6 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -486,7 +486,7 @@ def _resolve_uri(uri: str) -> str: def _safe_filename(value: str) -> str: - return sub(r"[^A-Za-z0-9_.-]+", "_", value) + return sub(r"[^A-Za-z0-9 _.-]+", "_", value) def _spread_lut(n_classes: int) -> np.ndarray: diff --git a/submit_report.py b/submit_report.py deleted file mode 100644 index e543fff5..00000000 --- a/submit_report.py +++ /dev/null @@ -1,21 +0,0 @@ -from kube_jobs import storage, submit_job - -config_dir = "/mnt/projects/tissue_classification/conf/reporter" -config_name = "tissue_classification_lbfgs_mug" - -submit_job( - job_name=f"tissue-classification-report-{config_name.replace('_', '-')}", - username="vcifka", - cpu=8, - memory="16Gi", - gpu=None, - public=False, - script=[ - "git clone https://gitlab.ics.muni.cz/rationai/digital-pathology/pathology/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv pip install git+ssh://git@gitlab.ics.muni.cz/rationai/digital-pathology/pipeline/report.git@feature/force-wsi-service-protocol", - f"uv run python -m report --config-dir {config_dir} --config-name={config_name} user=vcifka" - ], - storage=[storage.secure.DATA, storage.secure.PROJECTS], -) \ No newline at end of file From 3cd0243400d5a22a4be0458b69d48f1008b9d1f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 21:57:52 +0200 Subject: [PATCH 102/107] fix: log test accuracy as jsons --- ml/meta_arch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py 
index c76bb151..e1baae33 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -367,7 +367,7 @@ def _log_per_slide_accuracy(self) -> None: ] mlflow.log_table( data=pd.DataFrame(rows), - artifact_file="per_slide/test_tile_accuracy.parquet", + artifact_file="per_slide/test_tile_accuracy.json", ) From 76e41940f82f477c5d8ef3c07f0f1c1517bf013d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 22:01:58 +0200 Subject: [PATCH 103/107] chore: remove username from the submission script --- scripts/submit_test_linear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index dec0968b..b35769a2 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -2,17 +2,17 @@ submit_job( - job_name="tissue-classification-test-linear-final-adamw", + job_name="tissue-classification-test-linear-final", username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-test-mode https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier_test_adamw", + "uv run python -m ml +experiment=...", ], storage=[storage.secure.PROJECTS], ) From 597e348f7eb319c0d571c67151f56921c253fec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 22:28:22 +0200 Subject: [PATCH 104/107] fix: force the entering of the write phase of the prediction maps --- ml/callbacks/tiff_prediction_map_writer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index c9e82ad6..c71da816 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -98,11 +98,19 
@@ def on_test_epoch_end( ) -> None: self._write_maps(trainer) + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._write_maps(trainer) + def on_predict_epoch_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule ) -> None: self._write_maps(trainer) + def on_predict_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + self._write_maps(trainer) + def _write_maps(self, trainer: pl.Trainer) -> None: batches = _gather_batches(self._batches) self._batches.clear() From e4a4cc537b85b79e453b4bd351585e9e53f2cb31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 22:47:02 +0200 Subject: [PATCH 105/107] fix: surface why prediction-map write phase skips Clear the batch buffer only on rank!=0 or after a successful write so the on_test_end fallback no longer hits an always-empty buffer. Add diagnostic prints to the silent early-return guards and an idempotency flag so the two write hooks cooperate. 
Co-Authored-By: Claude Opus 4.7 --- ml/callbacks/tiff_prediction_map_writer.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index c71da816..d83b24d9 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -53,15 +53,18 @@ def __init__( self.max_slides = max_slides self.slide_selection = slide_selection self._batches: list[dict[str, Any]] = [] + self._written = False def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._batches.clear() + self._written = False print("[TiffPredictionMapWriter] test loop started", flush=True) def on_predict_start( self, trainer: pl.Trainer, pl_module: pl.LightningModule ) -> None: self._batches.clear() + self._written = False def on_test_batch_end( self, @@ -112,15 +115,28 @@ def on_predict_end( self._write_maps(trainer) def _write_maps(self, trainer: pl.Trainer) -> None: + if self._written: + return batches = _gather_batches(self._batches) - self._batches.clear() if trainer.global_rank != 0: + self._batches.clear() return if not batches: + print( + "[TiffPredictionMapWriter] no buffered batches at write time " + f"(local buffer={len(self._batches)}, " + f"dist_init={torch.distributed.is_available() and torch.distributed.is_initialized()}); " + "skipping", + flush=True, + ) return predictions = _batches_to_dataframe(batches) if predictions.empty: + print( + "[TiffPredictionMapWriter] predictions dataframe empty; skipping", + flush=True, + ) return slides = pd.read_parquet(_resolve_uri(self.slides_uri)) @@ -170,6 +186,9 @@ def _write_maps(self, trainer: pl.Trainer) -> None: ) mlflow.log_artifacts(output_dir, artifact_path=self.artifact_path) + self._written = True + self._batches.clear() + def _select_slide_groups( self, predictions: pd.DataFrame ) -> list[tuple[str, pd.DataFrame]]: From 
3829ebd6dd978613368b6bf62f843f2d494ff7c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 18 May 2026 22:59:06 +0200 Subject: [PATCH 106/107] fix: remove username --- scripts/submit_test_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/submit_test_linear.py b/scripts/submit_test_linear.py index b35769a2..3eccdea7 100644 --- a/scripts/submit_test_linear.py +++ b/scripts/submit_test_linear.py @@ -3,7 +3,7 @@ submit_job( job_name="tissue-classification-test-linear-final", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, From 79b47a2317505c6d167bb83559fa328440b447c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 19 May 2026 00:04:52 +0200 Subject: [PATCH 107/107] fix: preserve original wsi name --- ml/callbacks/tiff_prediction_map_writer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ml/callbacks/tiff_prediction_map_writer.py b/ml/callbacks/tiff_prediction_map_writer.py index d83b24d9..e8548dba 100644 --- a/ml/callbacks/tiff_prediction_map_writer.py +++ b/ml/callbacks/tiff_prediction_map_writer.py @@ -227,7 +227,7 @@ def _write_slide_maps( output_path: Path, class_names: list[str] | None, ) -> None: - filename = f"{_safe_filename(Path(str(slide['path'])).stem)}.tiff" + filename = _slide_prediction_filename(slide["path"]) extent = (int(slide["extent_x"]), int(slide["extent_y"])) tile_extent = (int(slide["tile_extent_x"]), int(slide["tile_extent_y"])) stride = (int(slide["stride_x"]), int(slide["stride_y"])) @@ -516,6 +516,10 @@ def _safe_filename(value: str) -> str: return sub(r"[^A-Za-z0-9 _.-]+", "_", value) +def _slide_prediction_filename(path: str | Path) -> str: + return Path(str(path)).with_suffix(".tiff").name + + def _spread_lut(n_classes: int) -> np.ndarray: """0-based class index -> evenly-spread uint8 value.