
Commit 7e251b6

Add concurrent insert in performance case (#741)
1. Fix concurrent insert memory and process cleanup
2. Add configurable load concurrency for performance cases
3. Make CLI Ctrl+C work by polling has_running() instead of blocking on concurrent.futures.wait(), which swallows SIGINT
4. Remove perf-case insert from SerialInsertRunner
5. Ignore the S608 lint rule and fix formatting

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
1 parent 6f7a153 commit 7e251b6

22 files changed

Lines changed: 730 additions & 152 deletions
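For item 3 of the commit message, a minimal sketch of the polling pattern (the `run_all`/`tasks` names and executor setup are illustrative, not the project's actual runner code): the commit replaces a blocking `concurrent.futures.wait()`, which it says swallows SIGINT, with a loop that polls a `has_running()` predicate between short interruptible sleeps, so KeyboardInterrupt can surface on the main thread.

```python
# Hypothetical sketch of the Ctrl+C fix; not the project's exact code.
import time
from concurrent.futures import Future, ThreadPoolExecutor


def run_all(tasks, max_workers: int = 4) -> None:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures: list[Future] = [executor.submit(t) for t in tasks]

        def has_running() -> bool:
            return any(not f.done() for f in futures)

        try:
            # Instead of concurrent.futures.wait(futures): time.sleep() on
            # the main thread is interruptible, so Ctrl+C can raise
            # KeyboardInterrupt between polls.
            while has_running():
                time.sleep(1)
        except KeyboardInterrupt:
            for f in futures:
                f.cancel()  # drop tasks that have not started yet
            raise
```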


pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -126,7 +126,7 @@ lint.ignore = [
     "INP001", # TODO
     "TID252", # TODO
     "N801", "N802", "N815",
-    "S101", "S108", "S603", "S311",
+    "S101", "S108", "S603", "S311", "S608",
     "PLR2004",
     "RUF017",
     "C416",
```

tests/test_concurrent_runner.py

Lines changed: 159 additions & 0 deletions
```python
"""Tests for ConcurrentInsertRunner against a running Milvus instance.

Includes:
- Correctness tests (threading & async backends)
- Parameterized benchmark: serial vs concurrent across (batch_size, workers) matrix

NUM_PER_BATCH is set via os.environ before each run. Since runners execute
task() in a spawn subprocess that re-imports config, the env var takes effect.

Requires:
- Milvus running at localhost:19530
- Network access to download OpenAI 50K dataset

Usage:
    pytest tests/test_concurrent_runner.py -v -s   # correctness tests only
    python tests/test_concurrent_runner.py         # full benchmark matrix
"""

# ruff: noqa: T201

from __future__ import annotations

import logging
import os
import time

from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.milvus.config import FLATConfig
from vectordb_bench.backend.dataset import Dataset, DatasetSource
from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner, ExecutorBackend
from vectordb_bench.backend.runner.serial_runner import SerialInsertRunner

log = logging.getLogger("vectordb_bench")
log.setLevel(logging.INFO)

DATASET_SIZE = 50_000


# ── Shared helpers ──────────────────────────────────────────────────────


def get_milvus_db(collection_name: str):
    return DB.Milvus.init_cls(
        dim=1536,
        db_config={"uri": "http://localhost:19530", "user": "", "password": ""},
        db_case_config=FLATConfig(metric_type="COSINE"),
        collection_name=collection_name,
        drop_old=True,
    )


def prepare_dataset():
    dataset = Dataset.OPENAI.manager(DATASET_SIZE)
    dataset.prepare(DatasetSource.AliyunOSS)
    return dataset


def set_batch_size(batch_size: int) -> None:
    os.environ["NUM_PER_BATCH"] = str(batch_size)


def timed_run(runner: SerialInsertRunner | ConcurrentInsertRunner) -> tuple[int, float]:
    start = time.perf_counter()
    count = runner.run()
    return count, time.perf_counter() - start


# ── Correctness tests (pytest) ──────────────────────────────────────────


def test_concurrent_insert_threading():
    """Test concurrent insert with threading backend."""
    db = get_milvus_db("test_conc_threading")
    runner = ConcurrentInsertRunner(
        db=db,
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.THREADING,
    )
    count = runner.run()
    assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


def test_concurrent_insert_async():
    """Test concurrent insert with async backend."""
    db = get_milvus_db("test_conc_async")
    runner = ConcurrentInsertRunner(
        db=db,
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.ASYNC,
    )
    count = runner.run()
    assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


# ── Parameterized benchmark ────────────────────────────────────────────


def run_serial(batch_size: int) -> tuple[int, float]:
    set_batch_size(batch_size)
    runner = SerialInsertRunner(
        db=get_milvus_db(f"bench_serial_b{batch_size}"),
        dataset=prepare_dataset(),
        normalize=False,
    )
    return timed_run(runner)


def run_concurrent(batch_size: int, workers: int) -> tuple[int, float]:
    set_batch_size(batch_size)
    runner = ConcurrentInsertRunner(
        db=get_milvus_db(f"bench_conc_b{batch_size}_w{workers}"),
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=workers,
        backend=ExecutorBackend.THREADING,
    )
    return timed_run(runner)


def bench_matrix():
    batch_sizes = [100, 500, 1000, 5000]
    worker_counts = [1, 2, 4, 8]

    conc_headers = [f"conc({w}w)" for w in worker_counts]
    speedup_headers = [f"speedup({w}w)" for w in worker_counts]
    print(f"\n{'Batch':>6} {'#Bat':>5} {'serial':>8}", end="")
    for h in conc_headers:
        print(f" {h:>10}", end="")
    for h in speedup_headers:
        print(f" {h:>12}", end="")
    print()
    print("-" * (22 + 10 * len(worker_counts) + 12 * len(worker_counts)))

    for bs in batch_sizes:
        n_batches = DATASET_SIZE // bs
        _, dur_s = run_serial(bs)

        conc_durs = []
        for w in worker_counts:
            _, dur_c = run_concurrent(bs, w)
            conc_durs.append(dur_c)

        print(f"{bs:>6} {n_batches:>5} {dur_s:>7.2f}s", end="")
        for dur_c in conc_durs:
            print(f" {dur_c:>9.2f}s", end="")
        for dur_c in conc_durs:
            print(f" {dur_s / dur_c:>11.2f}x", end="")
        print()

    # restore default
    set_batch_size(100)


if __name__ == "__main__":
    bench_matrix()
```
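The docstring's note about NUM_PER_BATCH works because of the spawn start method. A standalone, hypothetical illustration of that mechanism (not code from this commit):

```python
# Standalone illustration: with the "spawn" start method the child process
# re-imports its modules, so module-level config that reads os.environ at
# import time picks up values set in the parent just before the run.
import multiprocessing as mp
import os


def task() -> None:
    # In vectordb_bench this would be `from vectordb_bench import config`,
    # whose NUM_PER_BATCH is re-read from the environment on import.
    print("child sees NUM_PER_BATCH =", os.environ.get("NUM_PER_BATCH"))


if __name__ == "__main__":
    os.environ["NUM_PER_BATCH"] = "500"
    p = mp.get_context("spawn").Process(target=task)
    p.start()
    p.join()  # prints: child sees NUM_PER_BATCH = 500
```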

vectordb_bench/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -20,6 +20,7 @@ class config:
     DATASET_SOURCE = env.str("DATASET_SOURCE", "S3")  # Options "S3" or "AliyunOSS"
     DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
     NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100)
+    LOAD_CONCURRENCY = env.int("LOAD_CONCURRENCY", 0)  # 0 = cpu_count
     TIME_PER_BATCH = 1  # 1s. for streaming insertion.
     MAX_INSERT_RETRY = 5
     MAX_SEARCH_RETRY = 5
```
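How the `0 = cpu_count` sentinel would typically be resolved by a consumer of this setting (assumed usage, not code from this commit):

```python
import os

from vectordb_bench import config

# LOAD_CONCURRENCY == 0 is falsy, so the default falls through to the
# machine's CPU count (os.cpu_count() may return None, hence the final or 1).
max_workers = config.LOAD_CONCURRENCY or os.cpu_count() or 1
```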

vectordb_bench/backend/clients/alisql/alisql.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -107,15 +107,13 @@ def init(self):
         self.cursor.execute(f"SET SESSION vidx_hnsw_ef_search = {search_param['ef_search']}")
         self.cursor.execute("COMMIT")

-        self.insert_sql = (
-            f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)'  # noqa: S608
-        )
+        self.insert_sql = f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)'
         self.select_sql = (
-            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} '  # noqa: S608
+            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} '
             f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
         )
         self.select_sql_with_filter = (
-            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s '  # noqa: S608
+            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s '
             f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
         )
```

vectordb_bench/backend/clients/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ class VectorDB(ABC):
140140
supported_filter_types: list[FilterOp] = [FilterOp.NonFilter]
141141
name: str = ""
142142

143+
# Whether the client can share a single connection across threads.
144+
# If False, concurrent runners will deep-copy the instance and call
145+
# init() per thread instead of sharing the parent connection.
146+
thread_safe: bool = True
147+
143148
@classmethod
144149
def filter_supported(cls, filters: Filter) -> bool:
145150
"""Ensure that the filters are supported before testing filtering cases."""

vectordb_bench/backend/clients/doris/doris.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414

1515
class Doris(VectorDB):
16+
thread_safe: bool = False
17+
1618
def __init__(
1719
self,
1820
dim: int,

vectordb_bench/backend/clients/mariadb/mariadb.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -108,13 +108,13 @@ def init(self):
         self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
         self.cursor.execute("COMMIT")

-        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
         self.select_sql = (
-            f"SELECT id FROM {self.db_name}.{self.table_name}"  # noqa: S608
+            f"SELECT id FROM {self.db_name}.{self.table_name}"
             f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
         )
         self.select_sql_with_filter = (
-            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "  # noqa: S608
+            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "
             f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
         )
```

vectordb_bench/backend/clients/milvus/milvus.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -137,6 +137,16 @@ def init(self):
         self.client.close()
         self.client = None

+    def _wait_for_segments_sorted(self):
+        while True:
+            segments = self.client.list_persistent_segments(self.collection_name)
+            unsorted = [s for s in segments if not s.is_sorted]
+            if not unsorted:
+                log.info(f"{self.name} all persistent segments are sorted.")
+                break
+            log.debug(f"{self.name} waiting for {len(unsorted)} segments to be sorted...")
+            time.sleep(5)
+
     def _wait_for_index(self):
         while True:
             info = self.client.describe_index(self.collection_name, self._vector_index_name)
@@ -155,6 +165,7 @@ def _optimize(self):
         log.info(f"{self.name} optimizing before search")
         try:
             self.client.flush(self.collection_name)
+            self._wait_for_segments_sorted()
             self._wait_for_index()
             if self.case_config.is_gpu_index:
                 log.debug("skip force merge compaction for gpu index type.")
```

vectordb_bench/backend/clients/oceanbase/oceanbase.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -186,7 +186,7 @@ def insert_embeddings(
                 batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
                 values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
                 self._cursor.execute(
-                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"  # noqa: S608
+                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
                 )
                 insert_count += len(batch)
         except mysql.Error:
@@ -217,7 +217,7 @@ def search_embedding(
         packed = struct.pack(f"<{len(query)}f", *query)
         hex_vec = packed.hex()
         query_str = (
-            f"SELECT id FROM {self.table_name} "  # noqa: S608
+            f"SELECT id FROM {self.table_name} "
             f"{self.expr} ORDER BY "
             f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
             f"APPROXIMATE LIMIT {k}"
```

vectordb_bench/backend/clients/pgvector/pgvector.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
 class PgVector(VectorDB):
     """Use psycopg instructions"""

+    thread_safe: bool = False
     supported_filter_types: list[FilterOp] = [
         FilterOp.NonFilter,
         FilterOp.NumGE,
```
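Doris and PgVector are the two clients that opt out of the new `thread_safe` default, presumably because each instance holds a single connection and cursor; per the `VectorDB` comment above, the concurrent runners will deep-copy these instances and call `init()` per thread.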
