"""Tests for ConcurrentInsertRunner against a running Milvus instance.

Includes:
  - Correctness tests (threading & async backends)
  - Parameterized benchmark: serial vs concurrent across a (batch_size, workers) matrix

NUM_PER_BATCH is set via os.environ before each run. Because the runners execute
task() in a "spawn" subprocess that re-imports the config module, the env var takes
effect for every run.

Requires:
  - Milvus running at localhost:19530
  - Network access to download the OpenAI 50K dataset

Usage:
    pytest tests/test_concurrent_runner.py -v -s   # correctness tests only
    python tests/test_concurrent_runner.py         # full benchmark matrix
"""

# ruff: noqa: T201

from __future__ import annotations

import logging
import os
import time

from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.milvus.config import FLATConfig
from vectordb_bench.backend.dataset import Dataset, DatasetSource
from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner, ExecutorBackend
from vectordb_bench.backend.runner.serial_runner import SerialInsertRunner

log = logging.getLogger("vectordb_bench")
log.setLevel(logging.INFO)

DATASET_SIZE = 50_000


# ── Shared helpers ──────────────────────────────────────────────────────


def get_milvus_db(collection_name: str):
    """Create a fresh Milvus client bound to `collection_name`, dropping any old collection."""
    return DB.Milvus.init_cls(
        dim=1536,
        db_config={"uri": "http://localhost:19530", "user": "", "password": ""},
        db_case_config=FLATConfig(metric_type="COSINE"),
        collection_name=collection_name,
        drop_old=True,
    )


def prepare_dataset():
    """Download (via AliyunOSS) and return the OpenAI 50K dataset manager."""
    dataset = Dataset.OPENAI.manager(DATASET_SIZE)
    dataset.prepare(DatasetSource.AliyunOSS)
    return dataset


def set_batch_size(batch_size: int) -> None:
    """Override NUM_PER_BATCH; picked up when the runner subprocess re-imports config."""
    os.environ["NUM_PER_BATCH"] = str(batch_size)


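# Why the NUM_PER_BATCH override works (illustrative sketch, not exercised by the
# tests): a "spawn" subprocess inherits os.environ from its parent and re-imports
# the config module, so a module that reads the variable at import time sees the
# value set by set_batch_size(). The lookup below is an assumption about what
# vectordb_bench's config does, not its verified code.
def _num_per_batch_as_config_would_read_it(default: int = 100) -> int:
    return int(os.environ.get("NUM_PER_BATCH", str(default)))

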
def timed_run(runner: SerialInsertRunner | ConcurrentInsertRunner) -> tuple[int, float]:
    """Run the runner and return (inserted_count, elapsed_seconds)."""
    start = time.perf_counter()
    count = runner.run()
    return count, time.perf_counter() - start


# ── Correctness tests (pytest) ──────────────────────────────────────────


def test_concurrent_insert_threading():
    """Test concurrent insert with threading backend."""
    db = get_milvus_db("test_conc_threading")
    runner = ConcurrentInsertRunner(
        db=db,
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.THREADING,
    )
    count = runner.run()
    assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


def test_concurrent_insert_async():
    """Test concurrent insert with async backend."""
    db = get_milvus_db("test_conc_async")
    runner = ConcurrentInsertRunner(
        db=db,
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.ASYNC,
    )
    count = runner.run()
    assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


# ── Parameterized benchmark ─────────────────────────────────────────────


def run_serial(batch_size: int) -> tuple[int, float]:
    """Insert the full dataset serially with the given batch size."""
    set_batch_size(batch_size)
    runner = SerialInsertRunner(
        db=get_milvus_db(f"bench_serial_b{batch_size}"),
        dataset=prepare_dataset(),
        normalize=False,
    )
    return timed_run(runner)


def run_concurrent(batch_size: int, workers: int) -> tuple[int, float]:
    """Insert the full dataset concurrently (threading backend) with the given batch size and worker count."""
    set_batch_size(batch_size)
    runner = ConcurrentInsertRunner(
        db=get_milvus_db(f"bench_conc_b{batch_size}_w{workers}"),
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=workers,
        backend=ExecutorBackend.THREADING,
    )
    return timed_run(runner)


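# Example: run a single cell of the benchmark matrix by hand (assumes the same
# Milvus instance at localhost:19530 used by the tests above):
#
#     count, secs = run_concurrent(batch_size=1000, workers=4)
#     print(f"inserted {count} rows in {secs:.2f}s")

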
def bench_matrix():
    """Benchmark serial vs concurrent inserts across the (batch_size, workers) matrix."""
    batch_sizes = [100, 500, 1000, 5000]
    worker_counts = [1, 2, 4, 8]

    conc_headers = [f"conc({w}w)" for w in worker_counts]
    speedup_headers = [f"speedup({w}w)" for w in worker_counts]
    print(f"\n{'Batch':>6} {'#Bat':>5} {'serial':>8}", end="")
    for h in conc_headers:
        print(f" {h:>10}", end="")
    for h in speedup_headers:
        print(f" {h:>12}", end="")
    print()
    # 21 chars for the fixed columns, plus 11 per concurrent column and 13 per speedup column
    print("-" * (21 + 11 * len(worker_counts) + 13 * len(worker_counts)))

    for bs in batch_sizes:
        n_batches = DATASET_SIZE // bs
        _, dur_s = run_serial(bs)

        conc_durs = []
        for w in worker_counts:
            _, dur_c = run_concurrent(bs, w)
            conc_durs.append(dur_c)

        print(f"{bs:>6} {n_batches:>5} {dur_s:>7.2f}s", end="")
        for dur_c in conc_durs:
            print(f" {dur_c:>9.2f}s", end="")
        for dur_c in conc_durs:
            print(f" {dur_s / dur_c:>11.2f}x", end="")
        print()

    # restore default batch size
    set_batch_size(100)


if __name__ == "__main__":
    bench_matrix()