@@ -94,7 +94,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
9494 reranking = self .case_config .search_param ()["reranking" ]
9595 column_name = (
9696 sql .SQL ("binary_quantize({0})" ).format (sql .Identifier ("embedding" ))
97- if index_param ["quantization_type" ] == "bit"
97+ if index_param ["quantization_type" ] == "bit" and index_param [ "table_quantization_type" ] != "bit"
9898 else sql .SQL ("embedding" )
9999 )
100100 search_vector = (
@@ -104,7 +104,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
104104 )
105105
106106 # The following sections assume that the quantization_type value matches the quantization function name
107- if index_param ["quantization_type" ] is not None :
107+ if index_param ["quantization_type" ] != index_param ["table_quantization_type" ]:
108+ # Reranking makes sense only if table quantization is not "bit"
108109 if index_param ["quantization_type" ] == "bit" and reranking :
109110 # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
110111 search_query = sql .Composed (
@@ -113,7 +114,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
113114 """
114115 SELECT i.id
115116 FROM (
116- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
117+ SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
117118 FROM public.{table_name} {where_clause}
118119 ORDER BY {column_name}::{quantization_type}({dim})
119120 """ ,
@@ -123,21 +124,25 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
123124 reranking_metric_fun_op = sql .SQL (
124125 self .case_config .search_param ()["reranking_metric_fun_op" ],
125126 ),
127+ search_vector = search_vector ,
128+ table_quantization_type = sql .SQL (index_param ["table_quantization_type" ]),
126129 quantization_type = sql .SQL (index_param ["quantization_type" ]),
127130 dim = sql .Literal (self .dim ),
128131 where_clause = sql .SQL ("WHERE id >= %s" ) if filtered else sql .SQL ("" ),
129132 ),
130133 sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
131134 sql .SQL (
132135 """
133- {search_vector}
136+ {search_vector}::{quantization_type}({dim})
134137 LIMIT {quantized_fetch_limit}
135138 ) i
136139 ORDER BY i.distance
137140 LIMIT %s::int
138141 """ ,
139142 ).format (
140143 search_vector = search_vector ,
144+ quantization_type = sql .SQL (index_param ["quantization_type" ]),
145+ dim = sql .Literal (self .dim ),
141146 quantized_fetch_limit = sql .Literal (
142147 self .case_config .search_param ()["quantized_fetch_limit" ],
143148 ),
@@ -160,10 +165,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
160165 where_clause = sql .SQL ("WHERE id >= %s" ) if filtered else sql .SQL ("" ),
161166 ),
162167 sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
163- sql .SQL (" {search_vector} LIMIT %s::int" ).format (
168+ sql .SQL (" {search_vector}::{quantization_type}({dim}) LIMIT %s::int" ).format (
164169 search_vector = search_vector ,
170+ quantization_type = sql .SQL (index_param ["quantization_type" ]),
171+ dim = sql .Literal (self .dim ),
165172 ),
166- ],
173+ ]
167174 )
168175 else :
169176 search_query = sql .Composed (
@@ -175,8 +182,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
175182 where_clause = sql .SQL ("WHERE id >= %s" ) if filtered else sql .SQL ("" ),
176183 ),
177184 sql .SQL (self .case_config .search_param ()["metric_fun_op" ]),
178- sql .SQL (" %s::vector LIMIT %s::int" ),
179- ],
185+ sql .SQL (" {search_vector}::{quantization_type}({dim}) LIMIT %s::int" ).format (
186+ search_vector = search_vector ,
187+ quantization_type = sql .SQL (index_param ["quantization_type" ]),
188+ dim = sql .Literal (self .dim ),
189+ ),
190+ ]
180191 )
181192
182193 return search_query
@@ -323,7 +334,7 @@ def _create_index(self):
323334 )
324335 with_clause = sql .SQL ("WITH ({});" ).format (sql .SQL (", " ).join (options )) if any (options ) else sql .Composed (())
325336
326- if index_param ["quantization_type" ] is not None :
337+ if index_param ["quantization_type" ] != index_param [ "table_quantization_type" ] :
327338 index_create_sql = sql .SQL (
328339 """
329340 CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
@@ -365,14 +376,23 @@ def _create_table(self, dim: int):
365376 assert self .conn is not None , "Connection is not initialized"
366377 assert self .cursor is not None , "Cursor is not initialized"
367378
379+ index_param = self .case_config .index_param ()
380+
368381 try :
369382 log .info (f"{ self .name } client create table : { self .table_name } " )
370383
371384 # create table
372385 self .cursor .execute (
373386 sql .SQL (
374- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" ,
375- ).format (table_name = sql .Identifier (self .table_name ), dim = dim ),
387+ """
388+ CREATE TABLE IF NOT EXISTS public.{table_name}
389+ (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
390+ """
391+ ).format (
392+ table_name = sql .Identifier (self .table_name ),
393+ table_quantization_type = sql .SQL (index_param ["table_quantization_type" ]),
394+ dim = dim ,
395+ )
376396 )
377397 self .cursor .execute (
378398 sql .SQL (
@@ -393,18 +413,41 @@ def insert_embeddings(
393413 assert self .conn is not None , "Connection is not initialized"
394414 assert self .cursor is not None , "Cursor is not initialized"
395415
416+ index_param = self .case_config .index_param ()
417+
396418 try :
397419 metadata_arr = np .array (metadata )
398420 embeddings_arr = np .array (embeddings )
399421
400- with self .cursor .copy (
401- sql .SQL ("COPY public.{table_name} FROM STDIN (FORMAT BINARY)" ).format (
402- table_name = sql .Identifier (self .table_name ),
403- ),
404- ) as copy :
405- copy .set_types (["bigint" , "vector" ])
406- for i , row in enumerate (metadata_arr ):
407- copy .write_row ((row , embeddings_arr [i ]))
422+ if index_param ["table_quantization_type" ] == "bit" :
423+ with self .cursor .copy (
424+ sql .SQL ("COPY public.{table_name} FROM STDIN (FORMAT TEXT)" ).format (
425+ table_name = sql .Identifier (self .table_name )
426+ )
427+ ) as copy :
428+ # Same logic as pgvector binary_quantize
429+ for i , row in enumerate (metadata_arr ):
430+ embeddings_bit = ""
431+ for embedding in embeddings_arr [i ]:
432+ if embedding > 0 :
433+ embeddings_bit += "1"
434+ else :
435+ embeddings_bit += "0"
436+ copy .write_row ((str (row ), embeddings_bit ))
437+ else :
438+ with self .cursor .copy (
439+ sql .SQL ("COPY public.{table_name} FROM STDIN (FORMAT BINARY)" ).format (
440+ table_name = sql .Identifier (self .table_name )
441+ )
442+ ) as copy :
443+ if index_param ["table_quantization_type" ] == "halfvec" :
444+ copy .set_types (["bigint" , "halfvec" ])
445+ for i , row in enumerate (metadata_arr ):
446+ copy .write_row ((row , np .float16 (embeddings_arr [i ])))
447+ else :
448+ copy .set_types (["bigint" , "vector" ])
449+ for i , row in enumerate (metadata_arr ):
450+ copy .write_row ((row , embeddings_arr [i ]))
408451 self .conn .commit ()
409452
410453 if kwargs .get ("last_batch" ):
0 commit comments