Skip to content

Commit b42d09d

Browse files
committed
Update throughput arguments and core assignment
1 parent c9c4130 commit b42d09d

5 files changed

Lines changed: 152 additions & 206 deletions

File tree

sklbench/benchmarks/throughput_worker.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ===============================================================================
2-
# Copyright 2024 Intel Corporation
2+
# Copyright 2026 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -18,14 +18,13 @@
1818
import inspect
1919
import json
2020
import socket
21-
import sys
2221
import time
2322
from typing import Dict, List, Tuple
2423

2524
from ..datasets import load_data
2625
from ..datasets.transformer import split_and_transform_data
26+
from ..utils.barrier import recv_until
2727
from ..utils.bench_case import get_bench_case_value
28-
from ..utils.common import convert_to_numpy
2928
from ..utils.config import bench_case_filter
3029
from ..utils.custom_types import BenchCase
3130
from ..utils.logger import logger
@@ -39,16 +38,6 @@
3938
)
4039

4140

42-
def barrier_wait(sock: socket.socket, msg_send: bytes, msg_expect_prefix: bytes):
43-
"""Send a message and block until response from parent."""
44-
sock.sendall(msg_send)
45-
data = b""
46-
while not data.startswith(msg_expect_prefix):
47-
chunk = sock.recv(1024)
48-
if not chunk:
49-
raise ConnectionError("Barrier socket closed unexpectedly")
50-
data += chunk
51-
5241

5342
def run_measurement_loop(
5443
func, args: tuple, measurement_duration: float
@@ -184,12 +173,7 @@ def main():
184173
continue
185174

186175
# Wait for "go" signal from parent before each stage
187-
data = b""
188-
while b"go" not in data:
189-
chunk = sock.recv(1024)
190-
if not chunk:
191-
raise ConnectionError("Barrier socket closed unexpectedly")
192-
data += chunk
176+
recv_until(sock, b"go")
193177

194178
method_name = available_methods[0]
195179
method_instance, data_args = get_method_and_args(

sklbench/runner/arguments.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,44 +137,14 @@ def add_runner_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentPa
137137
action="store_true",
138138
help="Interrupt runner and exit if last benchmark failed with error.",
139139
)
140-
# throughput mode arguments
140+
# throughput mode
141141
parser.add_argument(
142142
"--throughput-mode",
143143
default=False,
144144
action="store_true",
145145
help="Run in throughput mode: multiple synchronized parallel instances "
146-
"with CPU pinning via numactl.",
147-
)
148-
parser.add_argument(
149-
"--num-instances",
150-
type=int,
151-
default=None,
152-
help="Number of parallel instances in throughput mode.",
153-
)
154-
parser.add_argument(
155-
"--cores-per-instance",
156-
type=int,
157-
default=None,
158-
help="CPU cores per instance in throughput mode.",
159-
)
160-
parser.add_argument(
161-
"--measurement-duration",
162-
type=float,
163-
default=60.0,
164-
help="Duration (seconds) for each measurement stage in throughput mode.",
165-
)
166-
parser.add_argument(
167-
"--emergency-timeout",
168-
type=float,
169-
default=3600.0,
170-
help="Emergency subprocess timeout (seconds). Safety net only.",
171-
)
172-
parser.add_argument(
173-
"--throughput-full-logs",
174-
default=False,
175-
action="store_true",
176-
help="Store per-iteration start_ts and duration_ms arrays in throughput results. "
177-
"Disabled by default to reduce output size.",
146+
"with CPU pinning via numactl. Configure via bench:num_instances, "
147+
"bench:cores_per_instance, bench:measurement_duration in config.",
178148
)
179149
# option to get parser description in Markdown table format for READMEs
180150
parser.add_argument(

sklbench/runner/throughput.py

Lines changed: 25 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ===============================================================================
2-
# Copyright 2024 Intel Corporation
2+
# Copyright 2026 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -16,14 +16,14 @@
1616

1717
import argparse
1818
import json
19-
import socket
2019
import subprocess
2120
import time
2221
from typing import Dict, List, Tuple, Union
2322

2423
import numpy as np
2524
from tqdm import tqdm
2625

26+
from ..utils.barrier import accept_and_wait, create_server, send_all, wait_all
2727
from ..utils.bench_case import get_bench_case_name, get_bench_case_value
2828
from ..utils.common import custom_format, hash_from_json_repr
2929
from ..utils.core_assignment import compute_core_assignments
@@ -37,60 +37,14 @@ def validate_throughput_args(
3737
):
3838
if num_instances is None or num_instances < 1:
3939
raise ValueError(
40-
"--num-instances is required and must be >= 1 in throughput mode"
40+
"bench:num_instances is required and must be >= 1 in throughput mode"
4141
)
4242
if cores_per_instance is None or cores_per_instance < 1:
4343
raise ValueError(
44-
"--cores-per-instance is required and must be >= 1 in throughput mode"
44+
"bench:cores_per_instance is required and must be >= 1 in throughput mode"
4545
)
4646
if measurement_duration <= 0:
47-
raise ValueError("--measurement-duration must be > 0")
48-
49-
50-
def create_barrier_server() -> Tuple[socket.socket, int]:
51-
"""Create a TCP server socket on localhost with OS-assigned port."""
52-
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
53-
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
54-
server.bind(("localhost", 0))
55-
server.listen(128)
56-
port = server.getsockname()[1]
57-
return server, port
58-
59-
60-
def wait_for_workers_ready(
61-
server: socket.socket, num_instances: int, timeout: float
62-
) -> List[socket.socket]:
63-
"""Accept connections from all workers and wait for 'ready' message."""
64-
server.settimeout(timeout)
65-
connections = []
66-
for _ in range(num_instances):
67-
conn, _ = server.accept()
68-
data = b""
69-
while b"ready" not in data:
70-
chunk = conn.recv(1024)
71-
if not chunk:
72-
raise ConnectionError("Worker disconnected before sending 'ready'")
73-
data += chunk
74-
connections.append(conn)
75-
return connections
76-
77-
78-
def send_go_to_all(connections: List[socket.socket]):
79-
"""Send 'go' signal to all workers."""
80-
for conn in connections:
81-
conn.sendall(b"go")
82-
83-
84-
def wait_for_workers_done(connections: List[socket.socket], timeout: float):
85-
"""Wait for 'done' message from all workers."""
86-
for conn in connections:
87-
conn.settimeout(timeout)
88-
data = b""
89-
while b"done" not in data:
90-
chunk = conn.recv(1024)
91-
if not chunk:
92-
raise ConnectionError("Worker disconnected before sending 'done'")
93-
data += chunk
47+
raise ValueError("bench:measurement_duration must be > 0")
9448

9549

9650
def validate_sync_quality(instance_outputs: List[Dict], stage: str):
@@ -137,7 +91,6 @@ def aggregate_stage_results(
13791
stage: str,
13892
measurement_duration: float,
13993
core_assignments: List[str],
140-
full_logs: bool = False,
14194
) -> Dict:
14295
"""Aggregate per-instance results into a single stage result entry."""
14396
instances = []
@@ -163,10 +116,6 @@ def aggregate_stage_results(
163116
}
164117
instance_entry.update(compute_instance_stats(durations, start_timestamps))
165118

166-
if full_logs:
167-
instance_entry["start_ts"] = start_timestamps
168-
instance_entry["duration_ms"] = durations
169-
170119
instances.append(instance_entry)
171120
all_iterations.append(iters)
172121

@@ -196,19 +145,25 @@ def run_single_throughput_case(
196145
measurement_duration: float,
197146
emergency_timeout: float,
198147
log_level: str,
199-
full_logs: bool = False,
200148
) -> Tuple[int, List[Dict]]:
201149
"""Run a single benchmark case in throughput mode."""
150+
# Preload dataset in parent process to avoid cache race condition
151+
# when multiple workers try to download/generate and save simultaneously
152+
from ..datasets import load_data
153+
154+
logger.info("Preloading dataset in parent process to populate cache")
155+
load_data(bench_case)
156+
202157
numa_conf = get_numa_cpus_conf()
203158
core_assignments = compute_core_assignments(
204-
num_instances, cores_per_instance, numa_conf if numa_conf else None
159+
num_instances, cores_per_instance, numa_conf or None
205160
)
206161

207162
logger.info(
208163
f"Core assignments for {num_instances} instances: {core_assignments}"
209164
)
210165

211-
server, port = create_barrier_server()
166+
server, port = create_server()
212167
logger.debug(f"Barrier server listening on localhost:{port}")
213168

214169
bench_case_str = json.dumps(bench_case).replace(" ", "")
@@ -236,33 +191,18 @@ def run_single_throughput_case(
236191
processes.append(proc)
237192

238193
try:
239-
# Wait for all workers to be ready (prep phase - unlimited, but bounded by emergency timeout)
240-
connections = wait_for_workers_ready(server, num_instances, emergency_timeout)
194+
connections = accept_and_wait(server, num_instances, b"ready", emergency_timeout)
241195
logger.info("All workers ready, starting measurement stages")
242196

243-
# Determine which stages exist
244-
estimator_methods_training = get_bench_case_value(
245-
bench_case, "algorithm:estimator_methods:training", None
246-
)
247-
estimator_methods_inference = get_bench_case_value(
248-
bench_case, "algorithm:estimator_methods:inference", None
249-
)
250-
stages = []
251-
if estimator_methods_training is not None:
252-
stages = ["training", "inference"]
253-
else:
254-
# default stages
255-
stages = ["training", "inference"]
256-
257-
stage_timeout = measurement_duration + 60 # extra time for one stage
197+
stages = ["training", "inference"]
198+
stage_timeout = measurement_duration + 60
258199

259200
for stage in stages:
260201
logger.info(f"Sending 'go' for {stage} stage")
261-
send_go_to_all(connections)
262-
wait_for_workers_done(connections, stage_timeout)
202+
send_all(connections, b"go")
203+
wait_all(connections, b"done", stage_timeout)
263204
logger.info(f"All workers done with {stage} stage")
264205

265-
# Close barrier connections
266206
for conn in connections:
267207
conn.close()
268208

@@ -326,7 +266,7 @@ def run_single_throughput_case(
326266

327267
for stage in stages:
328268
stage_result = aggregate_stage_results(
329-
instance_outputs, stage, measurement_duration, core_assignments, full_logs
269+
instance_outputs, stage, measurement_duration, core_assignments
330270
)
331271
if not stage_result:
332272
continue
@@ -368,13 +308,6 @@ def run_throughput_benchmarks(
368308
env_info = get_environment_info()
369309
environment_name = args.environment_name or hash_from_json_repr(env_info)
370310

371-
# Resolve global defaults from CLI
372-
default_num_instances = args.num_instances
373-
default_cores_per_instance = args.cores_per_instance
374-
default_measurement_duration = args.measurement_duration
375-
default_emergency_timeout = args.emergency_timeout
376-
full_logs = args.throughput_full_logs
377-
378311
results = []
379312
return_code = 0
380313

@@ -386,18 +319,14 @@ def run_throughput_benchmarks(
386319
)
387320
)
388321

389-
# Per-case config overrides CLI defaults
390-
num_instances = get_bench_case_value(
391-
bench_case, "bench:num_instances", default_num_instances
392-
)
393-
cores_per_instance = get_bench_case_value(
394-
bench_case, "bench:cores_per_instance", default_cores_per_instance
395-
)
322+
# All throughput parameters come from bench_case config
323+
num_instances = get_bench_case_value(bench_case, "bench:num_instances")
324+
cores_per_instance = get_bench_case_value(bench_case, "bench:cores_per_instance")
396325
measurement_duration = get_bench_case_value(
397-
bench_case, "bench:measurement_duration", default_measurement_duration
326+
bench_case, "bench:measurement_duration", 60.0
398327
)
399328
emergency_timeout = get_bench_case_value(
400-
bench_case, "bench:emergency_timeout", default_emergency_timeout
329+
bench_case, "bench:emergency_timeout", 3600.0
401330
)
402331

403332
try:
@@ -420,7 +349,6 @@ def run_throughput_benchmarks(
420349
measurement_duration,
421350
emergency_timeout,
422351
args.bench_log_level,
423-
full_logs,
424352
)
425353
if case_return_code != 0:
426354
return_code = case_return_code

sklbench/utils/barrier.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# ===============================================================================
2+
# Copyright 2026 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ===============================================================================
16+
17+
"""TCP socket barrier for synchronizing throughput mode worker processes."""
18+
19+
import socket
20+
from typing import List, Tuple
21+
22+
23+
def create_server() -> Tuple[socket.socket, int]:
24+
"""Create a TCP server socket on localhost with OS-assigned port."""
25+
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
26+
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
27+
server.bind(("localhost", 0))
28+
server.listen(128)
29+
port = server.getsockname()[1]
30+
return server, port
31+
32+
33+
def recv_until(sock: socket.socket, expected: bytes):
34+
"""Block until expected message is received on socket."""
35+
data = b""
36+
while expected not in data:
37+
chunk = sock.recv(1024)
38+
if not chunk:
39+
raise ConnectionError(
40+
f"Socket closed before receiving {expected!r}"
41+
)
42+
data += chunk
43+
44+
45+
def send_all(connections: List[socket.socket], message: bytes):
46+
"""Send message to all connections."""
47+
for conn in connections:
48+
conn.sendall(message)
49+
50+
51+
def accept_and_wait(
52+
server: socket.socket, num_connections: int, expected: bytes, timeout: float
53+
) -> List[socket.socket]:
54+
"""Accept num_connections and wait for expected message from each."""
55+
server.settimeout(timeout)
56+
connections = []
57+
for _ in range(num_connections):
58+
conn, _ = server.accept()
59+
recv_until(conn, expected)
60+
connections.append(conn)
61+
return connections
62+
63+
64+
def wait_all(connections: List[socket.socket], expected: bytes, timeout: float):
65+
"""Wait for expected message from all existing connections."""
66+
for conn in connections:
67+
conn.settimeout(timeout)
68+
recv_until(conn, expected)

0 commit comments

Comments
 (0)