Skip to content

Commit 9c40b0d

Browse files
committed
Add show_progress parameter to predict
1 parent 49abc72 commit 9c40b0d

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

deeplc/_model_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,12 @@ def predict(
148148
device: str = "cpu",
149149
batch_size: int = 512,
150150
num_workers: int = 0,
151+
show_progress: bool = True,
151152
) -> torch.Tensor:
152153
"""Predict using the model for the given dataset."""
153154
model = load_model(model, device)
154155
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
155-
predictions = _predict_epoch(model, data_loader, device, show_progress=True)
156+
predictions = _predict_epoch(model, data_loader, device, show_progress=show_progress)
156157
return predictions.cpu().detach()
157158

158159

0 commit comments

Comments
 (0)