Skip to content

Commit 4f4d64a

Browse files
Donglai Weiclaude
andcommitted
Skip real datamodule in decode-only mode, use dummy single-batch loader
When saved_prediction_path is set, creates a minimal datamodule with one dummy batch instead of loading test images. The actual prediction is loaded from the H5 file in _load_cached_predictions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1322a9a commit 4f4d64a

1 file changed

Lines changed: 34 additions & 2 deletions

File tree

scripts/main.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,35 @@ def _has_tta_prediction_file(cfg: Config) -> bool:
275275
return os.path.exists(pred_file) and _is_valid_hdf5_prediction_file(pred_file)
276276

277277

278+
def _create_decode_only_datamodule(cfg, saved_prediction_path: str):
279+
"""Create a minimal datamodule for decode-only mode.
280+
281+
Yields a single dummy batch so that trainer.test() triggers test_step
282+
once. The actual prediction is loaded from saved_prediction_path inside
283+
_load_cached_predictions.
284+
"""
285+
from pathlib import Path
286+
287+
import pytorch_lightning as pl
288+
from torch.utils.data import DataLoader, Dataset
289+
290+
pred_stem = Path(saved_prediction_path).stem
291+
292+
class _DummyDataset(Dataset):
293+
def __len__(self):
294+
return 1
295+
296+
def __getitem__(self, idx):
297+
# Return minimal dict that test_step expects
298+
return {"image": torch.zeros(1, 1, 1, 1), "filename": pred_stem}
299+
300+
class _DummyDataModule(pl.LightningDataModule):
301+
def test_dataloader(self):
302+
return DataLoader(_DummyDataset(), batch_size=1)
303+
304+
return _DummyDataModule()
305+
306+
278307
def _has_cached_predictions_in_output_dir(
279308
cfg: Config, mode: str, checkpoint_path: str | None = None
280309
) -> bool:
@@ -1003,8 +1032,11 @@ def main():
10031032
if not has_assigned_test_shard(cfg, args):
10041033
return
10051034

1006-
# Create datamodule
1007-
datamodule = create_datamodule(cfg, mode="test")
1035+
# Create datamodule (or dummy for decode-only mode)
1036+
if has_saved_prediction:
1037+
datamodule = _create_decode_only_datamodule(cfg, _saved_pred)
1038+
else:
1039+
datamodule = create_datamodule(cfg, mode="test")
10081040

10091041
# Apply test volume sharding across machines
10101042
if args.shard_id is not None and args.num_shards is not None:

0 commit comments

Comments
 (0)