1515from typing import (
1616 Any ,
1717 Generic ,
18+ Iterator ,
1819 Self ,
1920 TypeVar ,
2021)
22+ import unittest .mock
2123
2224import attr
2325from furl import (
@@ -235,23 +237,30 @@ def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query:
235237 Converts the given filters into an Elasticsearch DSL Query object.
236238 """
237239 filter_list = []
240+ plugin = self .plugin
241+ source_id_field_name = plugin .special_fields .source_id .name
242+ source_id_field_path = plugin .field_mapping [source_id_field_name ]
238243 for field_path , filter in self .prepared_filters .items ():
239244 if field_path not in skip_field_paths :
245+ values : Sequence [PrimitiveJSON ] | ToJsonTemplate
240246 operator , values = one (filter .items ())
247+ original_values = values
241248 # Note that `is_not` is only used internally (for filtering by
242249 # inaccessible sources)
243250 if operator in ('is' , 'is_not' ):
244251 field_type = self .service .field_type (self .catalog , field_path )
252+ if field_path == source_id_field_path :
253+ values = ToJsonTemplate (param_name = source_id_field_name , value = values )
245254 if isinstance (field_type , Nested ):
246255 term_queries = []
247- for nested_field , nested_value in one (values ).items ():
256+ for nested_field , nested_value in one (original_values ).items ():
248257 nested_body = {dotted (field_path , nested_field , 'keyword' ): nested_value }
249258 term_queries .append (Q ('term' , ** nested_body ))
250259 query = Q ('nested' , path = dotted (field_path ), query = Q ('bool' , must = term_queries ))
251260 else :
252261 query = Q ('terms' , ** {dotted (field_path , 'keyword' ): values })
253262 translated_none = field_type .to_index (None )
254- if translated_none in values :
263+ if translated_none in original_values :
255264 # Note that at this point None values in filters have already
256265 # been translated e.g. {'is': ['~null']} and if the filter has a
257266 # None our query needs to find fields with None values as well
@@ -668,6 +677,29 @@ def to_source(self) -> AnyJSON:
668677 raise NotImplementedError
669678
670679
680+ @attr .s (frozen = True , auto_attribs = True , kw_only = True )
681+ class ToJsonTemplate (Template ):
682+
683+ def to_source (self ):
684+ return RawStr ('{{#toJson}}' + self .param_name + '{{/toJson}}' )
685+
686+
687+ class RawStr (str ):
688+ """
689+ Instances of this class will not be surrounded by quotes when encoded as
690+ JSON using a :class:`TemplateSearchJSONEncoder`.
691+ """
692+
693+
694+ _original = json .encoder .py_encode_basestring_ascii
695+
696+
697+ def _encode_basestring_ascii (s : str ) -> str :
698+ result = _original (s )
699+ assert result [0 ] == result [- 1 ] == '"' , result
700+ return result [1 :- 1 ] if isinstance (s , RawStr ) else result
701+
702+
671703class TemplateSearchJSONEncoder (json .JSONEncoder ):
672704
673705 def __init__ (self , ** kwargs ):
@@ -686,6 +718,11 @@ def default(self, obj):
686718 else :
687719 return super ().default (obj )
688720
721+ def iterencode (self , o : AnyJSON , _one_shot : bool = False ) -> Iterator [str ]:
722+ with unittest .mock .patch ('json.encoder.encode_basestring_ascii' ,
723+ wraps = _encode_basestring_ascii ):
724+ return super ().iterencode (o , _one_shot = _one_shot )
725+
689726
690727class TemplateSearch (Search ):
691728
0 commit comments