|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import itertools |
4 | 3 | from collections import Counter |
| 4 | +from collections.abc import Iterator |
5 | 5 | from typing import Any, cast |
6 | 6 |
|
7 | 7 | import hyperbase.constants as const |
@@ -258,6 +258,37 @@ def _matches_atomic_pattern(edge: Hyperedge, atomic_pattern: Atom) -> bool: |
258 | 258 | # argroles # |
259 | 259 | ############ |
260 | 260 |
|
# Upper bound on the number of edge items considered for a single argrole;
# _match_by_argroles raises ValueError when an edge exceeds it.
_MAX_ARGROLE_ITEMS = 10
| 262 | + |
| 263 | + |
def _can_match_structurally(edge: Hyperedge, pattern: Hyperedge) -> bool:
    """Cheaply decide whether ``edge`` could possibly match ``pattern``.

    This is a structural pre-filter only: a ``True`` result means the pair
    was not ruled out here, not that a full match exists.

    * Atomic pattern: the answer is delegated to ``_matches_atomic_pattern``.
    * Functional pattern (``pattern.is_fun_pattern()``): never ruled out.
    * Any other non-atomic pattern: requires a non-atomic ``edge``.
    """
    if not pattern.atom:
        # Functional patterns are always let through by this filter; every
        # other non-atomic pattern can only match a non-atomic edge.
        return True if pattern.is_fun_pattern() else edge.not_atom
    return _matches_atomic_pattern(edge, cast(Atom, pattern))
| 272 | + |
| 273 | + |
| 274 | +def _valid_assignments( |
| 275 | + candidates: list[list[int]], |
| 276 | + pos: int = 0, |
| 277 | + used: set[int] | None = None, |
| 278 | +) -> Iterator[tuple[int, ...]]: |
| 279 | + """Backtracking generator of assignments constrained by candidate sets.""" |
| 280 | + if used is None: |
| 281 | + used = set() |
| 282 | + if pos == len(candidates): |
| 283 | + yield () |
| 284 | + return |
| 285 | + for idx in candidates[pos]: |
| 286 | + if idx not in used: |
| 287 | + used.add(idx) |
| 288 | + for rest in _valid_assignments(candidates, pos + 1, used): |
| 289 | + yield (idx, *rest) |
| 290 | + used.discard(idx) |
| 291 | + |
261 | 292 |
|
262 | 293 | def _match_by_argroles( |
263 | 294 | matcher: Matcher, |
@@ -296,21 +327,38 @@ def _match_by_argroles( |
296 | 327 | else: |
297 | 328 | return [] |
298 | 329 |
|
| 330 | + if len(eitems) > _MAX_ARGROLE_ITEMS: |
| 331 | + raise ValueError( |
| 332 | + f"Edge has {len(eitems)} items for argrole '{argrole}', " |
| 333 | + f"exceeding limit of {_MAX_ARGROLE_ITEMS}" |
| 334 | + ) |
| 335 | + |
| 336 | + # constraint propagation: pre-compute which eitems can match each pitem |
| 337 | + candidates: list[list[int]] = [ |
| 338 | + [j for j in range(len(eitems)) if _can_match_structurally(eitems[j], pitem)] |
| 339 | + for pitem in pitems |
| 340 | + ] |
| 341 | + |
| 342 | + # early exit if any pattern position has zero candidates |
| 343 | + if any(len(c) == 0 for c in candidates): |
| 344 | + if len(curvars) >= min_vars: |
| 345 | + return [curvars] |
| 346 | + else: |
| 347 | + return [] |
| 348 | + |
299 | 349 | result: list[dict[str, Hyperedge]] = [] |
300 | 350 |
|
301 | 351 | if tok_pos: |
302 | 352 | tok_pos_items = [ |
303 | 353 | tok_pos[i] for i, subedge in enumerate(edge) if subedge in eitems |
304 | 354 | ] |
305 | | - tok_pos_perms = tuple(itertools.permutations(tok_pos_items, r=n)) |
306 | 355 |
|
307 | | - for perm_n, perm in enumerate(tuple(itertools.permutations(eitems, r=n))): |
308 | | - if tok_pos: |
309 | | - tok_pos_perm = tok_pos_perms[perm_n] |
| 356 | + for assignment in _valid_assignments(candidates): |
| 357 | + perm = tuple(eitems[j] for j in assignment) |
310 | 358 | perm_result: list[dict[str, Hyperedge]] = [{}] |
311 | 359 | for i, eitem in enumerate(perm): |
312 | 360 | pitem = pitems[i] |
313 | | - tok_pos_item = tok_pos_perm[i] if tok_pos else None |
| 361 | + tok_pos_item = tok_pos_items[assignment[i]] if tok_pos else None |
314 | 362 | item_result: list[dict[str, Hyperedge]] = [] |
315 | 363 | for variables in perm_result: |
316 | 364 | item_result += matcher.match( |
|
0 commit comments