@@ -147,18 +147,47 @@ def test_time_inference(monkeypatch):
147147 assert result ["mean_inference_time_per_prediction" ] == pytest .approx (3.0 / 2 )
148148
149149
150- def test_evaluate_compute (monkeypatch ):
151- if not torch . cuda . is_available ():
152- pytest . skip ( "CUDA unavailable, skipping measure_memory_footprint test." )
def test_evaluate_compute_cpu_path(monkeypatch):
    """evaluate_compute on a CPU-only host: skips GPU memory profiling.

    Forces ``torch.cuda.is_available`` to report False, then checks that
    ``measure_memory_footprint`` is never invoked, the parameter count is
    still reported, the memory footprint comes back empty, and the main
    model checkpoint is loaded under the expected identifier.
    """
    # Pretend no CUDA device exists so the GPU branch cannot be taken.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

    # Sentinel flag: flips to True if the memory profiler is (wrongly) called.
    invoked = {"mem": False}

    def spy_measure(model, inputs, device):
        invoked["mem"] = True
        return {}, model

    monkeypatch.setattr(bf, "measure_memory_footprint", spy_measure)
    monkeypatch.setattr(bf, "count_trainable_parameters", lambda m: 12345)

    class OneBatchLoader:
        # Minimal stand-in for a DataLoader: yields a single input tuple.
        def __iter__(self):
            yield ("inp",)

    fake_model = FakeModel()
    surrogate = "SurrB"
    config = {"training_id": "TID"}
    result = bf.evaluate_compute(
        fake_model, surrogate, test_loader=OneBatchLoader(), conf=config
    )

    # The "main" checkpoint for this surrogate must have been loaded.
    assert fake_model.load_calls == [("TID", surrogate, f"{surrogate.lower()}_main")]
    assert result["num_trainable_parameters"] == 12345
    assert result["memory_footprint"] == {}
    # The CPU path must never touch the CUDA memory profiler.
    assert not invoked["mem"]
178+
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA unavailable for memory profiling test."
)
def test_evaluate_compute_cuda_path(monkeypatch):
    """evaluate_compute on a CUDA host: memory footprint is measured.

    Patches ``measure_memory_footprint`` to return a known dict and
    verifies that evaluate_compute loads the main checkpoint, reports the
    patched parameter count, and passes the profiler's result through
    unchanged (same object identity).
    """
    fake_mem = {"model_memory": 100, "forward_memory_nograd": 50}

    def stub_measure(model, inputs, device):
        return fake_mem, model

    monkeypatch.setattr(bf, "measure_memory_footprint", stub_measure)
    monkeypatch.setattr(bf, "count_trainable_parameters", lambda m: 12345)

    class OneBatchLoader:
        # Minimal stand-in for a DataLoader: yields a single input tuple.
        def __iter__(self):
            yield ("inp",)

    fake_model = FakeModel()
    surrogate = "SurrB"
    config = {"training_id": "TID"}
    result = bf.evaluate_compute(
        fake_model, surrogate, test_loader=OneBatchLoader(), conf=config
    )

    # The "main" checkpoint for this surrogate must have been loaded.
    assert fake_model.load_calls == [("TID", surrogate, f"{surrogate.lower()}_main")]
    assert result["num_trainable_parameters"] == 12345
    # The footprint dict must be the very object the profiler returned.
    assert result["memory_footprint"] is fake_mem
0 commit comments