Skip to content

Commit b332b0c

Browse files
authored
Remove dependency of ops on dm-tree (#323)
* Remove tree dependency from core * remove dataclass * syntactic_eq dataclass case removed * lint and revert * cases * lint * remove abstractmethod
1 parent bc23e4c commit b332b0c

3 files changed

Lines changed: 53 additions & 51 deletions

File tree

effectful/ops/semantics.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
import collections.abc
12
import contextlib
23
import functools
34
import types
45
import typing
56
from collections.abc import Callable
67
from typing import Any
78

8-
import tree
9-
109
from effectful.ops.syntax import deffn, defop
1110
from effectful.ops.types import Expr, Interpretation, Operation, Term
1211

@@ -68,7 +67,7 @@ def call[**P, T](fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
6867
}
6968
with handler(subs):
7069
return evaluate(body)
71-
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
70+
elif not fvsof((fn, args, kwargs)):
7271
return fn(*args, **kwargs)
7372
else:
7473
raise NotImplementedError
@@ -246,18 +245,37 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
246245
6
247246
248247
"""
249-
if intp is None:
250-
from effectful.internals.runtime import get_interpretation
248+
from effectful.internals.runtime import get_interpretation, interpreter
251249

252-
intp = get_interpretation()
250+
if intp is not None:
251+
return interpreter(intp)(evaluate)(expr)
253252

254253
if isinstance(expr, Term):
255-
(args, kwargs) = tree.map_structure(
256-
functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
257-
)
258-
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
259-
elif tree.is_nested(expr):
260-
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
254+
args = tuple(evaluate(arg) for arg in expr.args)
255+
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
256+
return expr.op(*args, **kwargs)
257+
elif isinstance(expr, Operation):
258+
op_intp = get_interpretation().get(expr, expr)
259+
return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
260+
elif isinstance(expr, collections.abc.Mapping):
261+
if isinstance(expr, collections.defaultdict):
262+
return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
263+
elif isinstance(expr, types.MappingProxyType):
264+
return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
265+
else:
266+
return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
267+
elif isinstance(expr, collections.abc.Sequence):
268+
if isinstance(expr, str | bytes):
269+
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
270+
else:
271+
return type(expr)(evaluate(item) for item in expr) # type: ignore
272+
elif isinstance(expr, collections.abc.Set):
273+
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
274+
return {evaluate(item) for item in expr} # type: ignore
275+
else:
276+
return type(expr)(evaluate(item) for item in expr) # type: ignore
277+
elif isinstance(expr, collections.abc.ValuesView):
278+
return [evaluate(item) for item in expr] # type: ignore
261279
else:
262280
return expr
263281

effectful/ops/syntax.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from collections.abc import Callable, Iterable, Mapping
99
from typing import Annotated, Concatenate
1010

11-
import tree
12-
1311
from effectful.ops.types import Annotation, Expr, Operation, Term
1412

1513

@@ -355,7 +353,7 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]:
355353
else:
356354
param_bound_vars = {param_value}
357355
elif param_ordinal: # Only process if there's a Scoped annotation
358-
# We can't use tree.flatten here because we want to be able
356+
# We can't use flatten here because we want to be able
359357
# to see dict keys
360358
def extract_operations(obj):
361359
if isinstance(obj, Operation):
@@ -662,7 +660,9 @@ def func() -> t: # type: ignore
662660
def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]:
663661
@functools.wraps(t)
664662
def func(*args, **kwargs):
665-
if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))):
663+
from effectful.ops.semantics import fvsof
664+
665+
if not fvsof((args, kwargs)):
666666
return t(*args, **kwargs)
667667
else:
668668
raise NotImplementedError
@@ -872,18 +872,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
872872
return __dispatch(type(value))(value)
873873

874874

875-
def _map_structure_and_keys(func, structure):
876-
def _map_value(value):
877-
if isinstance(value, dict):
878-
return {func(k): v for k, v in value.items()}
879-
elif not tree.is_nested(value):
880-
return func(value)
881-
else:
882-
return value
883-
884-
return tree.traverse(_map_value, structure, top_down=False)
885-
886-
887875
@_CustomSingleDispatchCallable
888876
def defdata[T](
889877
__dispatch: Callable[[type], Callable[..., Expr[T]]],
@@ -960,9 +948,6 @@ def _(op, *args, **kwargs):
960948
*{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(),
961949
):
962950
if c:
963-
v = _map_structure_and_keys(
964-
lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v
965-
)
966951
res = evaluate(
967952
v,
968953
intp={
@@ -1133,21 +1118,28 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool:
11331118
if isinstance(x, Term) and isinstance(other, Term):
11341119
op, args, kwargs = x.op, x.args, x.kwargs
11351120
op2, args2, kwargs2 = other.op, other.args, other.kwargs
1136-
try:
1137-
tree.assert_same_structure(
1138-
(op, args, kwargs), (op2, args2, kwargs2), check_types=True
1139-
)
1140-
except (TypeError, ValueError):
1141-
return False
1142-
return all(
1143-
tree.flatten(
1144-
tree.map_structure(
1145-
syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2)
1146-
)
1147-
)
1121+
return (
1122+
op == op2
1123+
and len(args) == len(args2)
1124+
and set(kwargs) == set(kwargs2)
1125+
and all(syntactic_eq(a, b) for a, b in zip(args, args2))
1126+
and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs)
11481127
)
11491128
elif isinstance(x, Term) or isinstance(other, Term):
11501129
return False
1130+
elif isinstance(x, collections.abc.Mapping) and isinstance(
1131+
other, collections.abc.Mapping
1132+
):
1133+
return all(
1134+
k in x and k in other and syntactic_eq(x[k], other[k])
1135+
for k in set(x) | set(other)
1136+
)
1137+
elif isinstance(x, collections.abc.Sequence) and isinstance(
1138+
other, collections.abc.Sequence
1139+
):
1140+
return len(x) == len(other) and all(
1141+
syntactic_eq(a, b) for a, b in zip(x, other)
1142+
)
11511143
else:
11521144
return x == other
11531145

tests/test_ops_syntax.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from effectful.ops.syntax import (
1111
Scoped,
1212
_CustomSingleDispatchCallable,
13-
_map_structure_and_keys,
1413
deffn,
1514
defop,
1615
defstream,
@@ -111,13 +110,6 @@ def f(x):
111110
assert f_op != ff_op
112111

113112

114-
def test_map_structure_and_keys():
115-
s = {1: 2, 3: [4, 5, (6, {7: 8})]}
116-
expected = {2: 3, 4: [5, 6, (7, {8: 9})]}
117-
actual = _map_structure_and_keys(lambda x: x + 1, s)
118-
assert actual == expected
119-
120-
121113
def test_scoped_collections():
122114
"""Test that Scoped annotations work with tree-structured collections containing Operations."""
123115

0 commit comments

Comments
 (0)