Skip to content

Commit 7bda989

Browse files
lucagiac81alwayslove2013
authored andcommitted
Add table quantization type
1 parent 220038e commit 7bda989

6 files changed

Lines changed: 118 additions & 26 deletions

File tree

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,11 @@ Options:
131131
--ef-construction INTEGER hnsw ef-construction
132132
--ef-search INTEGER hnsw ef-search
133133
--quantization-type [none|bit|halfvec]
134-
quantization type for vectors
134+
quantization type for vectors (in index)
135+
--table-quantization-type [none|bit|halfvec]
136+
quantization type for vectors (in table). If
137+
equal to bit, the parameter
138+
quantization_type will be set to bit too.
135139
--custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K
136140
--custom-case-description TEXT Custom name description
137141
--custom-case-load-timeout INTEGER

vectordb_bench/backend/clients/pgvector/cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict):
8282
click.option(
8383
"--quantization-type",
8484
type=click.Choice(["none", "bit", "halfvec"]),
85-
help="quantization type for vectors",
85+
help="quantization type for vectors (in index)",
86+
required=False,
87+
),
88+
]
89+
table_quantization_type: Annotated[
90+
str | None,
91+
click.option(
92+
"--table-quantization-type",
93+
type=click.Choice(["none", "bit", "halfvec"]),
94+
help="quantization type for vectors (in table). "
95+
"If equal to bit, the parameter quantization_type will be set to bit too.",
8696
required=False,
8797
),
8898
]
@@ -146,6 +156,7 @@ def PgVectorIVFFlat(
146156
lists=parameters["lists"],
147157
probes=parameters["probes"],
148158
quantization_type=parameters["quantization_type"],
159+
table_quantization_type=parameters["table_quantization_type"],
149160
reranking=parameters["reranking"],
150161
reranking_metric=parameters["reranking_metric"],
151162
quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -182,6 +193,7 @@ def PgVectorHNSW(
182193
maintenance_work_mem=parameters["maintenance_work_mem"],
183194
max_parallel_workers=parameters["max_parallel_workers"],
184195
quantization_type=parameters["quantization_type"],
196+
table_quantization_type=parameters["table_quantization_type"],
185197
reranking=parameters["reranking"],
186198
reranking_metric=parameters["reranking_metric"],
187199
quantized_fetch_limit=parameters["quantized_fetch_limit"],

vectordb_bench/backend/clients/pgvector/config.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ def parse_metric(self) -> str:
8080

8181
if d.get(self.quantization_type) is None:
8282
return d.get("_fallback").get(self.metric_type)
83-
return d.get(self.quantization_type).get(self.metric_type)
83+
metric = d.get(self.quantization_type).get(self.metric_type)
84+
# If using binary quantization for the index, use a bit metric
85+
# no matter what metric was selected for vector or halfvec data
86+
if self.quantization_type == "bit" and metric is None:
87+
return "bit_hamming_ops"
88+
return metric
8489

8590
def parse_metric_fun_op(self) -> LiteralString:
8691
if self.quantization_type == "bit":
@@ -168,21 +173,27 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
168173
maintenance_work_mem: str | None = None
169174
max_parallel_workers: int | None = None
170175
quantization_type: str | None = None
176+
table_quantization_type: str | None
171177
reranking: bool | None = None
172178
quantized_fetch_limit: int | None = None
173179
reranking_metric: str | None = None
174180

175181
def index_param(self) -> PgVectorIndexParam:
176182
index_parameters = {"lists": self.lists}
177-
if self.quantization_type == "none":
178-
self.quantization_type = None
183+
if self.quantization_type == "none" or self.quantization_type is None:
184+
self.quantization_type = "vector"
185+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
186+
self.table_quantization_type = "vector"
187+
if self.table_quantization_type == "bit":
188+
self.quantization_type = "bit"
179189
return {
180190
"metric": self.parse_metric(),
181191
"index_type": self.index.value,
182192
"index_creation_with_options": self._optionally_build_with_options(index_parameters),
183193
"maintenance_work_mem": self.maintenance_work_mem,
184194
"max_parallel_workers": self.max_parallel_workers,
185195
"quantization_type": self.quantization_type,
196+
"table_quantization_type": self.table_quantization_type,
186197
}
187198

188199
def search_param(self) -> PgVectorSearchParam:
@@ -212,21 +223,27 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
212223
maintenance_work_mem: str | None = None
213224
max_parallel_workers: int | None = None
214225
quantization_type: str | None = None
226+
table_quantization_type: str | None
215227
reranking: bool | None = None
216228
quantized_fetch_limit: int | None = None
217229
reranking_metric: str | None = None
218230

219231
def index_param(self) -> PgVectorIndexParam:
220232
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
221-
if self.quantization_type == "none":
222-
self.quantization_type = None
233+
if self.quantization_type == "none" or self.quantization_type is None:
234+
self.quantization_type = "vector"
235+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
236+
self.table_quantization_type = "vector"
237+
if self.table_quantization_type == "bit":
238+
self.quantization_type = "bit"
223239
return {
224240
"metric": self.parse_metric(),
225241
"index_type": self.index.value,
226242
"index_creation_with_options": self._optionally_build_with_options(index_parameters),
227243
"maintenance_work_mem": self.maintenance_work_mem,
228244
"max_parallel_workers": self.max_parallel_workers,
229245
"quantization_type": self.quantization_type,
246+
"table_quantization_type": self.table_quantization_type,
230247
}
231248

232249
def search_param(self) -> PgVectorSearchParam:

vectordb_bench/backend/clients/pgvector/pgvector.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
9494
reranking = self.case_config.search_param()["reranking"]
9595
column_name = (
9696
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
97-
if index_param["quantization_type"] == "bit"
97+
if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
9898
else sql.SQL("embedding")
9999
)
100100
search_vector = (
@@ -104,7 +104,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
104104
)
105105

106106
# The following sections assume that the quantization_type value matches the quantization function name
107-
if index_param["quantization_type"] is not None:
107+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
108+
# Reranking makes sense only if table quantization is not "bit"
108109
if index_param["quantization_type"] == "bit" and reranking:
109110
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
110111
search_query = sql.Composed(
@@ -113,7 +114,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
113114
"""
114115
SELECT i.id
115116
FROM (
116-
SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
117+
SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
117118
FROM public.{table_name} {where_clause}
118119
ORDER BY {column_name}::{quantization_type}({dim})
119120
""",
@@ -123,21 +124,25 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
123124
reranking_metric_fun_op=sql.SQL(
124125
self.case_config.search_param()["reranking_metric_fun_op"],
125126
),
127+
search_vector=search_vector,
128+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
126129
quantization_type=sql.SQL(index_param["quantization_type"]),
127130
dim=sql.Literal(self.dim),
128131
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
129132
),
130133
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
131134
sql.SQL(
132135
"""
133-
{search_vector}
136+
{search_vector}::{quantization_type}({dim})
134137
LIMIT {quantized_fetch_limit}
135138
) i
136139
ORDER BY i.distance
137140
LIMIT %s::int
138141
""",
139142
).format(
140143
search_vector=search_vector,
144+
quantization_type=sql.SQL(index_param["quantization_type"]),
145+
dim=sql.Literal(self.dim),
141146
quantized_fetch_limit=sql.Literal(
142147
self.case_config.search_param()["quantized_fetch_limit"],
143148
),
@@ -160,10 +165,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
160165
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
161166
),
162167
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
163-
sql.SQL(" {search_vector} LIMIT %s::int").format(
168+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
164169
search_vector=search_vector,
170+
quantization_type=sql.SQL(index_param["quantization_type"]),
171+
dim=sql.Literal(self.dim),
165172
),
166-
],
173+
]
167174
)
168175
else:
169176
search_query = sql.Composed(
@@ -175,8 +182,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
175182
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
176183
),
177184
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
178-
sql.SQL(" %s::vector LIMIT %s::int"),
179-
],
185+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
186+
search_vector=search_vector,
187+
quantization_type=sql.SQL(index_param["quantization_type"]),
188+
dim=sql.Literal(self.dim),
189+
),
190+
]
180191
)
181192

182193
return search_query
@@ -323,7 +334,7 @@ def _create_index(self):
323334
)
324335
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
325336

326-
if index_param["quantization_type"] is not None:
337+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
327338
index_create_sql = sql.SQL(
328339
"""
329340
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
@@ -365,14 +376,23 @@ def _create_table(self, dim: int):
365376
assert self.conn is not None, "Connection is not initialized"
366377
assert self.cursor is not None, "Cursor is not initialized"
367378

379+
index_param = self.case_config.index_param()
380+
368381
try:
369382
log.info(f"{self.name} client create table : {self.table_name}")
370383

371384
# create table
372385
self.cursor.execute(
373386
sql.SQL(
374-
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
375-
).format(table_name=sql.Identifier(self.table_name), dim=dim),
387+
"""
388+
CREATE TABLE IF NOT EXISTS public.{table_name}
389+
(id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
390+
"""
391+
).format(
392+
table_name=sql.Identifier(self.table_name),
393+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
394+
dim=dim,
395+
)
376396
)
377397
self.cursor.execute(
378398
sql.SQL(
@@ -393,18 +413,41 @@ def insert_embeddings(
393413
assert self.conn is not None, "Connection is not initialized"
394414
assert self.cursor is not None, "Cursor is not initialized"
395415

416+
index_param = self.case_config.index_param()
417+
396418
try:
397419
metadata_arr = np.array(metadata)
398420
embeddings_arr = np.array(embeddings)
399421

400-
with self.cursor.copy(
401-
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
402-
table_name=sql.Identifier(self.table_name),
403-
),
404-
) as copy:
405-
copy.set_types(["bigint", "vector"])
406-
for i, row in enumerate(metadata_arr):
407-
copy.write_row((row, embeddings_arr[i]))
422+
if index_param["table_quantization_type"] == "bit":
423+
with self.cursor.copy(
424+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
425+
table_name=sql.Identifier(self.table_name)
426+
)
427+
) as copy:
428+
# Same logic as pgvector binary_quantize
429+
for i, row in enumerate(metadata_arr):
430+
embeddings_bit = ""
431+
for embedding in embeddings_arr[i]:
432+
if embedding > 0:
433+
embeddings_bit += "1"
434+
else:
435+
embeddings_bit += "0"
436+
copy.write_row((str(row), embeddings_bit))
437+
else:
438+
with self.cursor.copy(
439+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
440+
table_name=sql.Identifier(self.table_name)
441+
)
442+
) as copy:
443+
if index_param["table_quantization_type"] == "halfvec":
444+
copy.set_types(["bigint", "halfvec"])
445+
for i, row in enumerate(metadata_arr):
446+
copy.write_row((row, np.float16(embeddings_arr[i])))
447+
else:
448+
copy.set_types(["bigint", "vector"])
449+
for i, row in enumerate(metadata_arr):
450+
copy.write_row((row, embeddings_arr[i]))
408451
self.conn.commit()
409452

410453
if kwargs.get("last_batch"):

vectordb_bench/frontend/config/dbCaseConfigs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,19 @@ class CaseConfigInput(BaseModel):
823823
],
824824
)
825825

826+
CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput(
827+
label=CaseConfigParamType.tableQuantizationType,
828+
inputType=InputType.Option,
829+
inputConfig={
830+
"options": ["none", "bit", "halfvec"],
831+
},
832+
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
833+
in [
834+
IndexType.HNSW.value,
835+
IndexType.IVFFlat.value,
836+
],
837+
)
838+
826839
CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput(
827840
label=CaseConfigParamType.max_parallel_workers,
828841
displayLabel="Max parallel workers",
@@ -1138,6 +1151,7 @@ class CaseConfigInput(BaseModel):
11381151
CaseConfigParamInput_m,
11391152
CaseConfigParamInput_EFConstruction_PgVector,
11401153
CaseConfigParamInput_QuantizationType_PgVector,
1154+
CaseConfigParamInput_TableQuantizationType_PgVector,
11411155
CaseConfigParamInput_maintenance_work_mem_PgVector,
11421156
CaseConfigParamInput_max_parallel_workers_PgVector,
11431157
]
@@ -1149,6 +1163,7 @@ class CaseConfigInput(BaseModel):
11491163
CaseConfigParamInput_Lists_PgVector,
11501164
CaseConfigParamInput_Probes_PgVector,
11511165
CaseConfigParamInput_QuantizationType_PgVector,
1166+
CaseConfigParamInput_TableQuantizationType_PgVector,
11521167
CaseConfigParamInput_maintenance_work_mem_PgVector,
11531168
CaseConfigParamInput_max_parallel_workers_PgVector,
11541169
CaseConfigParamInput_reranking_PgVector,

vectordb_bench/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class CaseConfigParamType(Enum):
4949
probes = "probes"
5050
quantizationType = "quantization_type"
5151
quantizationRatio = "quantization_ratio"
52+
tableQuantizationType = "table_quantization_type"
5253
reranking = "reranking"
5354
rerankingMetric = "reranking_metric"
5455
quantizedFetchLimit = "quantized_fetch_limit"

0 commit comments

Comments
 (0)