Skip to content

Commit 8211ca4

Browse files
Sort inference weight lists
1 parent 2e0e753 commit 8211ca4

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

src/spine/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def inference_single(cfg: dict) -> None:
136136
preloaded = True
137137
weights = [weights]
138138
else:
139+
weights = sorted(weights)
139140
weight_list = " - " + "\n - ".join(weights)
140141
logger.info(
141142
"Looping over %d set of weights:\n%s", len(weights), weight_list

test/test_conditional_imports.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,43 @@ def run(self):
203203

204204
assert calls == [cfg]
205205

206+
def test_inference_weight_list_runs_in_sorted_order(self, monkeypatch):
207+
"""Test multi-weight inference uses deterministic lexical ordering."""
208+
from spine import main
209+
210+
calls = []
211+
212+
class MockModel:
213+
weight_path = ["weights/b.ckpt", "weights/a.ckpt"]
214+
215+
def load_weights(self, weight_path):
216+
calls.append(("load", weight_path))
217+
218+
class MockDriver:
219+
model = MockModel()
220+
221+
def __init__(self, cfg):
222+
self.cfg = cfg
223+
224+
def initialize_log(self):
225+
calls.append(("log", None))
226+
227+
def run(self):
228+
calls.append(("run", None))
229+
230+
monkeypatch.setattr(main, "Driver", MockDriver)
231+
232+
main.inference_single({"base": {}, "io": {"reader": {"name": "hdf5"}}})
233+
234+
assert calls == [
235+
("load", "weights/a.ckpt"),
236+
("log", None),
237+
("run", None),
238+
("load", "weights/b.ckpt"),
239+
("log", None),
240+
("run", None),
241+
]
242+
206243
def test_cli_import_and_version(self):
207244
"""Test CLI imports and version detection works."""
208245
from spine.bin.cli import check_dependencies, get_version, main

0 commit comments

Comments
 (0)