Skip to content

Commit 892bad7

Browse files
authored
Merge pull request #1296 from VisLab/enhance_search
Implemented Phase II object-based search enhancements #1293
2 parents 94dec8d + 45dbef1 commit 892bad7

7 files changed

Lines changed: 410 additions & 48 deletions

File tree

examples/validate_bids_dataset_with_libraries.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"\n",
4545
"## Now validate URLs\n",
4646
"print(\"\\nNow validating with schema URLs.\")\n",
47-
"base_version = \"8.3.0\"\n",
47+
"base_version = \"8.4.0\"\n",
4848
"library1_url = (\n",
4949
" \"https://raw.githubusercontent.com/hed-standard/hed-schemas/main/\"\n",
5050
" + \"library_schemas/score/hedxml/HED_score_2.0.0.xml\"\n",

hed/models/hed_group.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,12 @@ def __eq__(self, other):
434434

435435
def find_tags(self, search_tags, recursive=False, include_groups=2) -> list:
436436
"""Find the base tags and their containing groups.
437-
This searches by short_base_tag, ignoring any ancestors or extensions/values.
437+
438+
Comparison property: ``short_base_tag`` (schema short name without any extension or value).
439+
Rationale: callers pass bare tag names such as ``"Event"`` or ``"Def"`` and must
440+
match regardless of any extension or value the tag carries in the source string.
441+
Using ``short_base_tag`` strips the extension/value so ``"Def/MyDef"`` is found
442+
by searching for ``"Def"``.
438443
439444
Parameters:
440445
search_tags (container): A container of short_base_tags to locate.
@@ -462,11 +467,16 @@ def find_tags(self, search_tags, recursive=False, include_groups=2) -> list:
462467
return found_tags
463468

464469
def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2) -> list:
465-
"""Find the tags and their containing groups.
466-
467-
This searches tag.short_tag.casefold(), with an implicit wildcard on the end.
470+
"""Find tags whose short form starts with a given prefix (implicit trailing wildcard).
468471
469-
e.g. "Eve" will find Event, but not Sensory-event.
472+
Comparison property: ``short_tag`` (schema short name *including* any extension or value).
473+
Rationale: the query is a prefix such as ``"Def/"`` or ``"Eve"``; the match must cover
474+
the extension/value as well so that ``"Def/MyDef"`` is found by ``"Def/"`` but not by
475+
an unrelated tag that merely shares the same base. ``short_tag`` is used (not
476+
``short_base_tag``) so that value-bearing tags like ``"Duration/3 s"`` can be matched
477+
by a prefix query such as ``"Duration/"``.
478+
Note: prefix matching is anchored to the start of ``short_tag`` only, so ``"Eve"``
479+
finds ``"Event"`` but not ``"Sensory-event"``.
470480
471481
Parameters:
472482
search_tags (container): A container of the starts of short tags to search.
@@ -499,15 +509,27 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2) ->
499509
return found_tags
500510

501511
def find_exact_tags(self, exact_tags, recursive=False, include_groups=1) -> list:
502-
"""Find the given tags. This will only find complete matches, any extension or value must also match.
512+
"""Find tags that match exactly, including any extension or value.
513+
514+
Comparison property: ``HedTag.__eq__`` which compares ``short_tag.casefold()``
515+
(falling back to ``org_tag.casefold()`` for unrecognised tags).
516+
Rationale: callers pass a slash-path string such as ``"def/mydef"`` and need an
517+
exact full-path match — the extension/value is part of the identity (``"Def/Foo"``
518+
must not match ``"Def/Bar"``). Because ``HedTag.__str__`` returns ``short_tag`` when
519+
the tag is schema-identified, a tag written in long form in the source HED string
520+
(e.g. ``"Event/Sensory-event"``) will still be found by a short-form query
521+
(``"Sensory-event"``); the schema normalises them to the same ``short_tag``.
522+
Unrecognised tags fall back to a case-insensitive comparison of the original text.
503523
504524
Parameters:
505-
exact_tags (list of HedTag): A container of tags to locate.
525+
exact_tags (list of str or HedTag): Tags to locate; each is compared via
526+
``HedTag.__eq__``, which accepts both ``str`` and ``HedTag`` operands.
506527
recursive (bool): If true, also check subgroups.
507-
include_groups (bool): 0, 1 or 2.
528+
include_groups (0, 1 or 2):
508529
If 0: Return only tags
509530
If 1: Return only groups
510531
If 2 or any other value: Return both
532+
511533
Returns:
512534
list: A list of tuples. The contents depend on the values of the include_group.
513535
"""
@@ -564,12 +586,21 @@ def _get_def_tags_from_group(group):
564586
return def_tags
565587

566588
def find_tags_with_term(self, term, recursive=False, include_groups=2) -> list:
567-
"""Find any tags that contain the given term.
568-
569-
Note: This can only find identified tags.
589+
"""Find tags whose schema ancestry includes the given term.
590+
591+
Comparison property: ``tag_terms`` — a tuple of all path components in the tag's
592+
long-form schema path, all casefolded (e.g. ``("event", "sensory-event")`` for the
593+
``Sensory-event`` tag).
594+
Rationale: this implements HED's *ancestor search* — a bare query term such as
595+
``"Event"`` must match not only the ``Event`` tag itself but also every descendant
596+
(``Sensory-event``, ``Agent-action``, etc.) because those descendants inherit the
597+
``Event`` parent. ``tag_terms`` encodes the full ancestry, so membership testing
598+
(``term in tag.tag_terms``) handles all descendants in O(k) time where k is the
599+
schema depth. This requires a schema-identified tag; unidentified tags have an
600+
empty ``tag_terms`` tuple and will not be found.
570601
571602
Parameters:
572-
term (str): A single term to search for.
603+
term (str): A single term to search for (compared case-insensitively).
573604
recursive (bool): If true, recursively check subgroups.
574605
include_groups (0, 1 or 2): Controls return values
575606
If 0: Return only tags.
@@ -579,6 +610,7 @@ def find_tags_with_term(self, term, recursive=False, include_groups=2) -> list:
579610
Returns:
580611
list: A list of tuples. The contents depend on the values of the include_group.
581612
"""
613+
# Note: unidentified tags (tag_terms == ()) are silently skipped.
582614
found_tags = []
583615
if recursive:
584616
tags = self.get_all_tags()

hed/models/query_expressions.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def merge_and_groups(groups1, groups2):
120120
list: Groups in both lists narrowed down to results where none of the children overlap.
121121
"""
122122
return_list = []
123+
seen = set()
123124
for group in groups1:
124125
for other_group in groups2:
125126
if group.group is other_group.group:
@@ -128,16 +129,9 @@ def merge_and_groups(groups1, groups2):
128129
continue
129130
# Merge the two groups' children into one new result, now that we've verified they're unique
130131
merged_result = group.merge_and_result(other_group)
131-
132-
dont_add = False
133-
# This is trash and slow
134-
for finalized_value in return_list:
135-
if merged_result.has_same_children(finalized_value):
136-
dont_add = True
137-
break
138-
if dont_add:
139-
continue
140-
return_list.append(merged_result)
132+
if merged_result not in seen:
133+
seen.add(merged_result)
134+
return_list.append(merged_result)
141135

142136
return return_list
143137

@@ -190,9 +184,7 @@ def handle_expr(self, hed_group, exact=False):
190184
for child in group.groups():
191185
groups_found.append((child, group))
192186

193-
# Wildcards are only found in containing groups. I believe this is correct.
194-
# todo: Is this code still needed for this kind of wildcard? We already are registering every group, just not
195-
# every group at every level.
187+
# Wildcards are only found in containing groups — not propagated to every parent level.
196188
all_found_groups = [SearchResult(group, tag) for tag, group in groups_found]
197189
return all_found_groups
198190

@@ -217,16 +209,9 @@ def handle_expr(self, hed_group, exact=False):
217209
groups1 = self.left.handle_expr(hed_group, exact=exact)
218210
# Don't early out as we need to gather all groups in case children appear more than once etc
219211
groups2 = self.right.handle_expr(hed_group, exact=exact)
220-
# todo: optimize this eventually
221-
# Filter out duplicates
222-
duplicates = []
223-
for group in groups1:
224-
for other_group in groups2:
225-
if group.has_same_children(other_group):
226-
duplicates.append(group)
227-
228-
groups1 = [group for group in groups1 if not any(other_group is group for other_group in duplicates)]
229-
212+
# Filter out results from groups1 that are already represented in groups2
213+
groups2_set = set(groups2)
214+
groups1 = [g for g in groups1 if g not in groups2_set]
230215
return groups1 + groups2
231216

232217
def __str__(self):
@@ -258,16 +243,9 @@ def handle_expr(self, hed_group, exact=False):
258243
259244
"""
260245
found_groups = self.right.handle_expr(hed_group, exact=exact)
261-
262-
# Todo: this may need more thought with respects to wildcards and negation
263-
# negated_groups = [group for group in hed_group.get_all_groups() if group not in groups]
264-
# This simpler version works on python >= 3.10
265-
# negated_groups = [SearchResult(group, []) for group in hed_group.get_all_groups() if group not in groups]
266-
# Python 3.7/8 compatible version.
246+
found_group_ids = {id(found_group.group) for found_group in found_groups}
267247
negated_groups = [
268-
SearchResult(group, [])
269-
for group in hed_group.get_all_groups()
270-
if not any(group is found_group.group for found_group in found_groups)
248+
SearchResult(group, []) for group in hed_group.get_all_groups() if id(group) not in found_group_ids
271249
]
272250

273251
return negated_groups

hed/models/query_handler.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _handle_negation(self):
137137
next_token = self._next_token_is([Token.LogicalNegation])
138138
if next_token == Token.LogicalNegation:
139139
interior = self._handle_grouping_op()
140-
if "?" in str(interior):
140+
if self._expr_has_wildcard(interior):
141141
raise HedQueryError(
142142
"Cannot negate wildcards, or expressions that contain wildcards."
143143
"Use {required_expression : optional_expression}."
@@ -147,6 +147,19 @@ def _handle_negation(self):
147147
else:
148148
return self._handle_grouping_op()
149149

150+
@staticmethod
151+
def _expr_has_wildcard(expr):
152+
"""Return True if the expression tree contains any wildcard node."""
153+
if expr is None:
154+
return False
155+
if isinstance(expr, ExpressionWildcardNew):
156+
return True
157+
if QueryHandler._expr_has_wildcard(expr.left):
158+
return True
159+
if QueryHandler._expr_has_wildcard(expr.right):
160+
return True
161+
return False
162+
150163
def _handle_grouping_op(self):
151164
next_token = self._next_token_is([Token.LogicalGroup, Token.DescendantGroup, Token.ExactMatch])
152165
if next_token == Token.LogicalGroup:

hed/models/query_util.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def merge_and_result(self, other):
5151
new_children.append(child)
5252
new_children.sort(key=str)
5353

54-
if self.group != other.group:
54+
if self.group is not other.group:
5555
raise ValueError("Internal error")
5656
return SearchResult(self.group, new_children)
5757

@@ -64,14 +64,26 @@ def has_same_children(self, other):
6464
Returns:
6565
bool: True if both results have the same group and identical children.
6666
"""
67-
if self.group != other.group:
67+
if self.group is not other.group:
6868
return False
6969

7070
if len(self.children) != len(other.children):
7171
return False
7272

7373
return all(child is child2 for child, child2 in zip(self.children, other.children, strict=False))
7474

75+
def __eq__(self, other):
76+
if not isinstance(other, SearchResult):
77+
return NotImplemented
78+
return (
79+
self.group is other.group
80+
and len(self.children) == len(other.children)
81+
and all(c is c2 for c, c2 in zip(self.children, other.children, strict=False))
82+
)
83+
84+
def __hash__(self):
85+
return hash((id(self.group), tuple(id(c) for c in self.children)))
86+
7587
# Backward compatibility alias
7688
has_same_tags = has_same_children
7789

tests/models/test_query_handler.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,130 @@ def test_match_mode_constants(self):
747747
self.assertNotEqual(Expression.MATCH_TERM, Expression.MATCH_EXACT)
748748
self.assertNotEqual(Expression.MATCH_EXACT, Expression.MATCH_WILDCARD)
749749
self.assertNotEqual(Expression.MATCH_TERM, Expression.MATCH_WILDCARD)
750+
751+
def test_ancestor_term_search(self):
752+
"""Bare term uses MATCH_TERM, which searches tag_terms (long-form path components).
753+
754+
'Event' (no slash, no asterisk) matches any tag whose schema ancestry includes 'event',
755+
including the tag 'Event' itself and 'Sensory-event' whose long-form path is
756+
'Event/Sensory-event'.
757+
"""
758+
test_strings = {
759+
"Event": True,
760+
"Sensory-event": True, # 'event' is an ancestor component via tag_terms
761+
"Clear-throat": False, # under Action hierarchy, not Event
762+
}
763+
self.base_test("Event", test_strings)
764+
765+
def test_term_search_case_insensitive(self):
766+
"""MATCH_TERM searches are case-insensitive via casefold()."""
767+
test_strings = {
768+
"Event": True,
769+
"Sensory-event": True,
770+
"Clear-throat": False,
771+
}
772+
self.base_test("event", test_strings)
773+
self.base_test("EVENT", test_strings)
774+
775+
def test_wildcard_matches_short_tag_prefix(self):
776+
"""MATCH_WILDCARD (asterisk) prefix-matches on short_tag, not on the long-form path.
777+
778+
'Sensory*' matches 'Sensory-event' (short_tag starts with 'sensory').
779+
'Event*' matches 'Event' but NOT 'Sensory-event' because short_tag='Sensory-event'
780+
does not start with 'event'.
781+
"""
782+
test_strings = {
783+
"Sensory-event": True,
784+
"Event": False, # short_tag 'Event' does not start with 'sensory'
785+
}
786+
self.base_test("Sensory*", test_strings)
787+
788+
test_strings = {
789+
"Event": True,
790+
"Sensory-event": False, # short_tag 'Sensory-event' does not start with 'event'
791+
}
792+
self.base_test("Event*", test_strings)
793+
794+
def test_slash_path_query_uses_short_tag_comparison(self):
795+
"""A slash-path query uses MATCH_EXACT, comparing the query string against short_tag.
796+
797+
'Event/Sensory-event' does NOT match the stored tag 'Sensory-event' because
798+
MATCH_EXACT compares the full query path 'event/sensory-event' against
799+
short_tag.casefold() = 'sensory-event'.
800+
801+
Value-extension tags (e.g. 'Def/Name') are an exception: their short_tag already
802+
includes the slash and value, so 'Def/Name' does match.
803+
"""
804+
test_strings = {
805+
"Sensory-event": False, # short_tag is 'Sensory-event', not 'Event/Sensory-event'
806+
"Event": False,
807+
}
808+
self.base_test("Event/Sensory-event", test_strings)
809+
810+
# --- Negation wildcard check (Phase C) ---
811+
812+
def test_negation_wildcard_raises(self):
813+
"""A wildcard token anywhere inside a negation must raise HedQueryError.
814+
815+
The check uses an expression-tree walk (_expr_has_wildcard) rather than
816+
a string search on str(interior), so it handles wildcards at any nesting
817+
depth inside the negation.
818+
"""
819+
invalid_queries = [
820+
"~?", # bare wildcard negated
821+
"~??", # tag-only wildcard negated
822+
"~???", # group wildcard negated
823+
"~(A, ?)", # wildcard inside a logical group that is negated
824+
"~(A, ??)", # tag wildcard inside negated group
825+
"~(A && ?)", # wildcard as right operand of AND inside negation
826+
"~[A, ?]", # wildcard inside a descendant group that is negated
827+
]
828+
for query in invalid_queries:
829+
with self.subTest(query=query):
830+
with self.assertRaises(HedQueryError):
831+
QueryHandler(query)
832+
833+
def test_negation_no_wildcard_valid(self):
834+
"""Negation without wildcards must parse without error."""
835+
valid_queries = [
836+
"~A",
837+
"~(A)",
838+
"~(A && B)",
839+
"~(A || B)",
840+
"~[A]",
841+
"~[A, B]",
842+
]
843+
for query in valid_queries:
844+
with self.subTest(query=query):
845+
qh = QueryHandler(query)
846+
self.assertIsNotNone(qh.tree)
847+
848+
# --- OR dedup (Phase A) ---
849+
850+
def test_or_no_duplicate_results(self):
851+
"""A || A must not return the same SearchResult more than once.
852+
853+
With the set-based dedup introduced in Phase A, results present in both
854+
sides of an OR are filtered before the lists are concatenated.
855+
"""
856+
expression = QueryHandler("Event || Event")
857+
hed_string = HedString("Event", self.hed_schema)
858+
results = expression.search(hed_string)
859+
for i, r1 in enumerate(results):
860+
for j, r2 in enumerate(results):
861+
if i != j:
862+
self.assertFalse(
863+
r1 == r2,
864+
f"Duplicate SearchResult at positions {i} and {j}: {r1}",
865+
)
866+
867+
def test_or_dedup_count(self):
868+
"""A || A should return the same number of results as A alone."""
869+
hed_string = HedString("Event", self.hed_schema)
870+
single_results = QueryHandler("Event").search(hed_string)
871+
or_results = QueryHandler("Event || Event").search(hed_string)
872+
self.assertEqual(
873+
len(single_results),
874+
len(or_results),
875+
"Event || Event should match the same number of groups as Event alone",
876+
)

0 commit comments

Comments
 (0)