11# ===============================================================================
2- # Copyright 2024 Intel Corporation
2+ # Copyright 2026 Intel Corporation
33#
44# Licensed under the Apache License, Version 2.0 (the "License");
55# you may not use this file except in compliance with the License.
1616
1717import argparse
1818import json
19- import socket
2019import subprocess
2120import time
2221from typing import Dict , List , Tuple , Union
2322
2423import numpy as np
2524from tqdm import tqdm
2625
26+ from ..utils .barrier import accept_and_wait , create_server , send_all , wait_all
2727from ..utils .bench_case import get_bench_case_name , get_bench_case_value
2828from ..utils .common import custom_format , hash_from_json_repr
2929from ..utils .core_assignment import compute_core_assignments
@@ -37,60 +37,14 @@ def validate_throughput_args(
3737):
3838 if num_instances is None or num_instances < 1 :
3939 raise ValueError (
40- "--num-instances is required and must be >= 1 in throughput mode"
40+ "bench:num_instances is required and must be >= 1 in throughput mode"
4141 )
4242 if cores_per_instance is None or cores_per_instance < 1 :
4343 raise ValueError (
44- "--cores-per-instance is required and must be >= 1 in throughput mode"
44+ "bench:cores_per_instance is required and must be >= 1 in throughput mode"
4545 )
4646 if measurement_duration <= 0 :
47- raise ValueError ("--measurement-duration must be > 0" )
48-
49-
50- def create_barrier_server () -> Tuple [socket .socket , int ]:
51- """Create a TCP server socket on localhost with OS-assigned port."""
52- server = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
53- server .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
54- server .bind (("localhost" , 0 ))
55- server .listen (128 )
56- port = server .getsockname ()[1 ]
57- return server , port
58-
59-
60- def wait_for_workers_ready (
61- server : socket .socket , num_instances : int , timeout : float
62- ) -> List [socket .socket ]:
63- """Accept connections from all workers and wait for 'ready' message."""
64- server .settimeout (timeout )
65- connections = []
66- for _ in range (num_instances ):
67- conn , _ = server .accept ()
68- data = b""
69- while b"ready" not in data :
70- chunk = conn .recv (1024 )
71- if not chunk :
72- raise ConnectionError ("Worker disconnected before sending 'ready'" )
73- data += chunk
74- connections .append (conn )
75- return connections
76-
77-
78- def send_go_to_all (connections : List [socket .socket ]):
79- """Send 'go' signal to all workers."""
80- for conn in connections :
81- conn .sendall (b"go" )
82-
83-
84- def wait_for_workers_done (connections : List [socket .socket ], timeout : float ):
85- """Wait for 'done' message from all workers."""
86- for conn in connections :
87- conn .settimeout (timeout )
88- data = b""
89- while b"done" not in data :
90- chunk = conn .recv (1024 )
91- if not chunk :
92- raise ConnectionError ("Worker disconnected before sending 'done'" )
93- data += chunk
47+ raise ValueError ("bench:measurement_duration must be > 0" )
9448
9549
9650def validate_sync_quality (instance_outputs : List [Dict ], stage : str ):
@@ -137,7 +91,6 @@ def aggregate_stage_results(
13791 stage : str ,
13892 measurement_duration : float ,
13993 core_assignments : List [str ],
140- full_logs : bool = False ,
14194) -> Dict :
14295 """Aggregate per-instance results into a single stage result entry."""
14396 instances = []
@@ -163,10 +116,6 @@ def aggregate_stage_results(
163116 }
164117 instance_entry .update (compute_instance_stats (durations , start_timestamps ))
165118
166- if full_logs :
167- instance_entry ["start_ts" ] = start_timestamps
168- instance_entry ["duration_ms" ] = durations
169-
170119 instances .append (instance_entry )
171120 all_iterations .append (iters )
172121
@@ -196,19 +145,25 @@ def run_single_throughput_case(
196145 measurement_duration : float ,
197146 emergency_timeout : float ,
198147 log_level : str ,
199- full_logs : bool = False ,
200148) -> Tuple [int , List [Dict ]]:
201149 """Run a single benchmark case in throughput mode."""
150+ # Preload dataset in parent process to avoid cache race condition
151+ # when multiple workers try to download/generate and save simultaneously
152+ from ..datasets import load_data
153+
154+ logger .info ("Preloading dataset in parent process to populate cache" )
155+ load_data (bench_case )
156+
202157 numa_conf = get_numa_cpus_conf ()
203158 core_assignments = compute_core_assignments (
204- num_instances , cores_per_instance , numa_conf if numa_conf else None
159+ num_instances , cores_per_instance , numa_conf or None
205160 )
206161
207162 logger .info (
208163 f"Core assignments for { num_instances } instances: { core_assignments } "
209164 )
210165
211- server , port = create_barrier_server ()
166+ server , port = create_server ()
212167 logger .debug (f"Barrier server listening on localhost:{ port } " )
213168
214169 bench_case_str = json .dumps (bench_case ).replace (" " , "" )
@@ -236,33 +191,18 @@ def run_single_throughput_case(
236191 processes .append (proc )
237192
238193 try :
239- # Wait for all workers to be ready (prep phase - unlimited, but bounded by emergency timeout)
240- connections = wait_for_workers_ready (server , num_instances , emergency_timeout )
194+ connections = accept_and_wait (server , num_instances , b"ready" , emergency_timeout )
241195 logger .info ("All workers ready, starting measurement stages" )
242196
243- # Determine which stages exist
244- estimator_methods_training = get_bench_case_value (
245- bench_case , "algorithm:estimator_methods:training" , None
246- )
247- estimator_methods_inference = get_bench_case_value (
248- bench_case , "algorithm:estimator_methods:inference" , None
249- )
250- stages = []
251- if estimator_methods_training is not None :
252- stages = ["training" , "inference" ]
253- else :
254- # default stages
255- stages = ["training" , "inference" ]
256-
257- stage_timeout = measurement_duration + 60 # extra time for one stage
197+ stages = ["training" , "inference" ]
198+ stage_timeout = measurement_duration + 60
258199
259200 for stage in stages :
260201 logger .info (f"Sending 'go' for { stage } stage" )
261- send_go_to_all (connections )
262- wait_for_workers_done (connections , stage_timeout )
202+ send_all (connections , b"go" )
203+ wait_all (connections , b"done" , stage_timeout )
263204 logger .info (f"All workers done with { stage } stage" )
264205
265- # Close barrier connections
266206 for conn in connections :
267207 conn .close ()
268208
@@ -326,7 +266,7 @@ def run_single_throughput_case(
326266
327267 for stage in stages :
328268 stage_result = aggregate_stage_results (
329- instance_outputs , stage , measurement_duration , core_assignments , full_logs
269+ instance_outputs , stage , measurement_duration , core_assignments
330270 )
331271 if not stage_result :
332272 continue
@@ -368,13 +308,6 @@ def run_throughput_benchmarks(
368308 env_info = get_environment_info ()
369309 environment_name = args .environment_name or hash_from_json_repr (env_info )
370310
371- # Resolve global defaults from CLI
372- default_num_instances = args .num_instances
373- default_cores_per_instance = args .cores_per_instance
374- default_measurement_duration = args .measurement_duration
375- default_emergency_timeout = args .emergency_timeout
376- full_logs = args .throughput_full_logs
377-
378311 results = []
379312 return_code = 0
380313
@@ -386,18 +319,14 @@ def run_throughput_benchmarks(
386319 )
387320 )
388321
389- # Per-case config overrides CLI defaults
390- num_instances = get_bench_case_value (
391- bench_case , "bench:num_instances" , default_num_instances
392- )
393- cores_per_instance = get_bench_case_value (
394- bench_case , "bench:cores_per_instance" , default_cores_per_instance
395- )
322+ # All throughput parameters come from bench_case config
323+ num_instances = get_bench_case_value (bench_case , "bench:num_instances" )
324+ cores_per_instance = get_bench_case_value (bench_case , "bench:cores_per_instance" )
396325 measurement_duration = get_bench_case_value (
397- bench_case , "bench:measurement_duration" , default_measurement_duration
326+ bench_case , "bench:measurement_duration" , 60.0
398327 )
399328 emergency_timeout = get_bench_case_value (
400- bench_case , "bench:emergency_timeout" , default_emergency_timeout
329+ bench_case , "bench:emergency_timeout" , 3600.0
401330 )
402331
403332 try :
@@ -420,7 +349,6 @@ def run_throughput_benchmarks(
420349 measurement_duration ,
421350 emergency_timeout ,
422351 args .bench_log_level ,
423- full_logs ,
424352 )
425353 if case_return_code != 0 :
426354 return_code = case_return_code
0 commit comments