|
8 | 8 | from collections.abc import Callable, Iterable, Mapping |
9 | 9 | from typing import Annotated, Concatenate |
10 | 10 |
|
11 | | -import tree |
12 | | - |
13 | 11 | from effectful.ops.types import Annotation, Expr, Operation, Term |
14 | 12 |
|
15 | 13 |
|
@@ -355,7 +353,7 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]: |
355 | 353 | else: |
356 | 354 | param_bound_vars = {param_value} |
357 | 355 | elif param_ordinal: # Only process if there's a Scoped annotation |
358 | | - # We can't use tree.flatten here because we want to be able |
| 356 | + # We can't use flatten here because we want to be able |
359 | 357 | # to see dict keys |
360 | 358 | def extract_operations(obj): |
361 | 359 | if isinstance(obj, Operation): |
@@ -662,7 +660,9 @@ def func() -> t: # type: ignore |
662 | 660 | def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]: |
663 | 661 | @functools.wraps(t) |
664 | 662 | def func(*args, **kwargs): |
665 | | - if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))): |
| 663 | + from effectful.ops.semantics import fvsof |
| 664 | + |
| 665 | + if not fvsof((args, kwargs)): |
666 | 666 | return t(*args, **kwargs) |
667 | 667 | else: |
668 | 668 | raise NotImplementedError |
@@ -872,18 +872,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T): |
872 | 872 | return __dispatch(type(value))(value) |
873 | 873 |
|
874 | 874 |
|
875 | | -def _map_structure_and_keys(func, structure): |
876 | | - def _map_value(value): |
877 | | - if isinstance(value, dict): |
878 | | - return {func(k): v for k, v in value.items()} |
879 | | - elif not tree.is_nested(value): |
880 | | - return func(value) |
881 | | - else: |
882 | | - return value |
883 | | - |
884 | | - return tree.traverse(_map_value, structure, top_down=False) |
885 | | - |
886 | | - |
887 | 875 | @_CustomSingleDispatchCallable |
888 | 876 | def defdata[T]( |
889 | 877 | __dispatch: Callable[[type], Callable[..., Expr[T]]], |
@@ -960,9 +948,6 @@ def _(op, *args, **kwargs): |
960 | 948 | *{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(), |
961 | 949 | ): |
962 | 950 | if c: |
963 | | - v = _map_structure_and_keys( |
964 | | - lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v |
965 | | - ) |
966 | 951 | res = evaluate( |
967 | 952 | v, |
968 | 953 | intp={ |
@@ -1133,21 +1118,28 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: |
1133 | 1118 | if isinstance(x, Term) and isinstance(other, Term): |
1134 | 1119 | op, args, kwargs = x.op, x.args, x.kwargs |
1135 | 1120 | op2, args2, kwargs2 = other.op, other.args, other.kwargs |
1136 | | - try: |
1137 | | - tree.assert_same_structure( |
1138 | | - (op, args, kwargs), (op2, args2, kwargs2), check_types=True |
1139 | | - ) |
1140 | | - except (TypeError, ValueError): |
1141 | | - return False |
1142 | | - return all( |
1143 | | - tree.flatten( |
1144 | | - tree.map_structure( |
1145 | | - syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2) |
1146 | | - ) |
1147 | | - ) |
| 1121 | + return ( |
| 1122 | + op == op2 |
| 1123 | + and len(args) == len(args2) |
| 1124 | + and set(kwargs) == set(kwargs2) |
| 1125 | + and all(syntactic_eq(a, b) for a, b in zip(args, args2)) |
| 1126 | + and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs) |
1148 | 1127 | ) |
1149 | 1128 | elif isinstance(x, Term) or isinstance(other, Term): |
1150 | 1129 | return False |
| 1130 | + elif isinstance(x, collections.abc.Mapping) and isinstance( |
| 1131 | + other, collections.abc.Mapping |
| 1132 | + ): |
| 1133 | + return all( |
| 1134 | + k in x and k in other and syntactic_eq(x[k], other[k]) |
| 1135 | + for k in set(x) | set(other) |
| 1136 | + ) |
| 1137 | + elif isinstance(x, collections.abc.Sequence) and isinstance( |
| 1138 | + other, collections.abc.Sequence |
| 1139 | + ): |
| 1140 | + return len(x) == len(other) and all( |
| 1141 | + syntactic_eq(a, b) for a, b in zip(x, other) |
| 1142 | + ) |
1151 | 1143 | else: |
1152 | 1144 | return x == other |
1153 | 1145 |
|
|
0 commit comments