@@ -86,17 +86,19 @@ def test_rql_generation(self):
8686 q4 = session .query (object_type = Dto ).vector_search ("VectorField" , "aaaa==" )
8787 self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q4 ._to_string ())
8888
89- q5 = session .query (object_type = Dto ).vector_search_text_i8 ("TextField" , "aaaa" )
89+ q5 = session .query (object_type = Dto ).vector_search_text ("TextField" , "aaaa" , target_quantization = VectorEmbeddingType . INT8 )
9090 self .assertEqual ("from 'Dtoes' where vector.search(embedding.text_i8(TextField), $p0)" , q5 ._to_string ())
9191
9292 q6 = session .query (object_type = Dto ).vector_search_i8 ("EmbeddingField" , [2 , 3 ], 0.65 )
9393 self .assertEqual (
9494 "from 'Dtoes' where vector.search(embedding.i8(EmbeddingField), $p0, 0.65, null)" , q6 ._to_string ()
9595 )
9696
97- q7 = session .query (object_type = Dto ).vector_search_text_i8 ("TextField" , "aaaa" )
97+ q7 = session .query (object_type = Dto ).vector_search_text ("TextField" , "aaaa" , target_quantization = VectorEmbeddingType . INT8 )
9898 self .assertEqual ("from 'Dtoes' where vector.search(embedding.text_i8(TextField), $p0)" , q7 ._to_string ())
9999
100+ # q8 = session.query(object_type=Dto).vector_search_with_field()
101+
100102 def test_rql_generation_2 (self ):
101103 with self .store .open_session () as session :
102104
@@ -115,12 +117,13 @@ def test_rql_generation_2(self):
115117 "from 'Dtoes' where vector.search(embedding.i8(EmbeddingField), $p0, 0.65, null)" , q1 ._to_string ()
116118 )
117119
118- q2 = session .query (object_type = Dto ).vector_search_f32_i8 ("EmbeddingField" , [2.5 , 3.3 ], 0.65 )
120+ q2 = session .query (object_type = Dto ).vector_search ("EmbeddingField" , [2.5 , 3.3 ], 0.65 , target_quantization = VectorEmbeddingType .INT8
121+ )
119122 self .assertEqual (
120123 "from 'Dtoes' where vector.search(embedding.f32_i8(EmbeddingField), $p0, 0.65, null)" , q2 ._to_string ()
121124 )
122125
123- q3 = session .query (object_type = Dto ).vector_search_f32_i8 ("EmbeddingField" , "abcd==" , 0.75 )
126+ q3 = session .query (object_type = Dto ).vector_search ("EmbeddingField" , "abcd==" , 0.75 , target_quantization = VectorEmbeddingType . INT8 )
124127 self .assertEqual (
125128 "from 'Dtoes' where vector.search(embedding.f32_i8(EmbeddingField), $p0, 0.75, null)" , q3 ._to_string ()
126129 )
@@ -144,6 +147,72 @@ def test_rql_generation_2(self):
144147 )
145148 self .assertEqual ("from 'Dtoes' where exact(vector.search(EmbeddingBase64, $p0, null, 25))" , q8 ._to_string ())
146149
150+ def test_rql_generation_3 (self ):
151+ with self .store .open_session () as session :
152+ # forDocument - text/field
153+ q1 = session .query (object_type = Dto ).vector_search_with_field_for_document ("VectorField" , "docs/1-A" )
154+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, embedding.forDoc($p0))" , q1 ._to_string ())
155+
156+ q2 = session .query (object_type = Dto ).vector_search_text_for_document ("VectorField" , "docs/1-A" , target_quantization = VectorEmbeddingType .INT8 )
157+ self .assertEqual (
158+ "from 'Dtoes' where vector.search(embedding.text_i8(VectorField), embedding.forDoc($p0))" , q2 ._to_string ()
159+ )
160+
161+ # withField
162+ q3 = session .query (object_type = Dto ).vector_search_with_field ("VectorField" , [0.1 , 0.2 , 0.3 ])
163+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q3 ._to_string ())
164+
165+ q4 = session .query (object_type = Dto ).vector_search_with_text_field ("VectorField" , "hello" )
166+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q4 ._to_string ())
167+
168+ q5 = session .query (object_type = Dto ).vector_search_with_i8_field ("VectorField" , [1 , 2 , 3 ])
169+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q5 ._to_string ())
170+
171+ q6 = session .query (object_type = Dto ).vector_search_with_i1_field ("VectorField" , [0 , 1 , 0 ])
172+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q6 ._to_string ())
173+
174+ # with base64
175+ q7 = session .query (object_type = Dto ).vector_search_with_base64 ("VectorField" , "abcd==" )
176+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q7 ._to_string ())
177+
178+ q8 = session .query (object_type = Dto ).vector_search_with_base64_i8 ("VectorField" , "abcd==" )
179+ self .assertEqual ("from 'Dtoes' where vector.search(embedding.i8(VectorField), $p0)" , q8 ._to_string ())
180+
181+ q9 = session .query (object_type = Dto ).vector_search_with_base64_i1 ("VectorField" , "abcd==" )
182+ self .assertEqual ("from 'Dtoes' where vector.search(embedding.i1(VectorField), $p0)" , q9 ._to_string ())
183+
184+ # ability to search in base64
185+ q10 = session .query (object_type = Dto ).vector_search ("VectorField" , "abcd==" )
186+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q10 ._to_string ())
187+
188+ q11 = session .query (object_type = Dto ).vector_search_i8 ("VectorField" , "abcd==" )
189+ self .assertEqual ("from 'Dtoes' where vector.search(embedding.i8(VectorField), $p0)" , q11 ._to_string ())
190+
191+ q12 = session .query (object_type = Dto ).vector_search_i1 ("VectorField" , "abcd==" )
192+ self .assertEqual ("from 'Dtoes' where vector.search(embedding.i1(VectorField), $p0)" , q12 ._to_string ())
193+
194+ q13 = session .query (object_type = Dto ).vector_search_with_field ("VectorField" , "abcd==" )
195+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q13 ._to_string ())
196+
197+ q14 = session .query (object_type = Dto ).vector_search_with_i8_field ("VectorField" , "abcd==" )
198+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q14 ._to_string ())
199+
200+ q15 = session .query (object_type = Dto ).vector_search_with_i1_field ("VectorField" , "abcd==" )
201+ self .assertEqual ("from 'Dtoes' where vector.search(VectorField, $p0)" , q15 ._to_string ())
202+
203+ # embeddingTaskIdentifier
204+ q16 = session .query (object_type = Dto ).vector_search_text ("VectorField" , "hello" , embedding_generation_task_identifier = "my-ai-task" )
205+ self .assertEqual (
206+ "from 'Dtoes' where vector.search(embedding.text(VectorField, ai.task('my-ai-task')), $p0)" ,
207+ q16 ._to_string (),
208+ )
209+
210+ q17 = session .query (object_type = Dto ).vector_search_text_for_document ("VectorField" , "hello" , embedding_generation_task_identifier = "my-ai-task" )
211+ self .assertEqual (
212+ "from 'Dtoes' where vector.search(embedding.text(VectorField, ai.task('my-ai-task')), embedding.forDoc($p0))" ,
213+ q17 ._to_string (),
214+ )
215+
147216 def test_embedding_dimensions_check (self ):
148217 with self .store .open_session () as session :
149218 dto1 = Dto (embedding_singles = [0.5 , - 1.0 ])
0 commit comments