Skip to content

Commit ff262cb

Browse files
authored
Add support for dataclass in evaluate (#324)
* Remove tree dependency from core * remove dataclass * syntactic_eq dataclass case removed * dataclass cases * lint and revert * cases * lint * remove abstractmethod * add unit test * fix copying * lint
1 parent b332b0c commit ff262cb

3 files changed

Lines changed: 63 additions & 1 deletion

File tree

effectful/ops/semantics.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections.abc
22
import contextlib
3+
import dataclasses
34
import functools
45
import types
56
import typing
@@ -276,8 +277,19 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
276277
return type(expr)(evaluate(item) for item in expr) # type: ignore
277278
elif isinstance(expr, collections.abc.ValuesView):
278279
return [evaluate(item) for item in expr] # type: ignore
280+
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
281+
return typing.cast(
282+
T,
283+
dataclasses.replace(
284+
expr,
285+
**{
286+
field.name: evaluate(getattr(expr, field.name))
287+
for field in dataclasses.fields(expr)
288+
},
289+
),
290+
)
279291
else:
280-
return expr
292+
return typing.cast(T, expr)
281293

282294

283295
def typeof[T](term: Expr[T]) -> type[T]:

effectful/ops/syntax.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,19 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool:
11401140
return len(x) == len(other) and all(
11411141
syntactic_eq(a, b) for a, b in zip(x, other)
11421142
)
1143+
elif (
1144+
dataclasses.is_dataclass(x)
1145+
and not isinstance(x, type)
1146+
and dataclasses.is_dataclass(other)
1147+
and not isinstance(other, type)
1148+
):
1149+
return type(x) == type(other) and syntactic_eq(
1150+
{field.name: getattr(x, field.name) for field in dataclasses.fields(x)},
1151+
{
1152+
field.name: getattr(other, field.name)
1153+
for field in dataclasses.fields(other)
1154+
},
1155+
)
11431156
else:
11441157
return x == other
11451158

tests/test_ops_syntax.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import functools
23
import inspect
34
from collections.abc import Callable, Iterable, Iterator, Mapping
@@ -541,3 +542,39 @@ def test_defstream_1():
541542
# assert isinstance(tm_iter_next, numbers.Number) # TODO
542543
# assert issubclass(typeof(tm_iter_next), numbers.Number)
543544
assert tm_iter_next.op is next_
545+
546+
547+
def test_eval_dataclass():
548+
@dataclasses.dataclass
549+
class Point:
550+
x: int
551+
y: int
552+
553+
@dataclasses.dataclass
554+
class Line:
555+
start: Point
556+
end: Point
557+
558+
@dataclasses.dataclass
559+
class Lines:
560+
origin: Point
561+
lines: list[Line]
562+
563+
x, y = defop(int, name="x"), defop(int, name="y")
564+
p1 = Point(x(), y())
565+
p2 = Point(x() + 1, y() + 1)
566+
line = Line(p1, p2)
567+
lines = Lines(p1, [line])
568+
569+
assert {x, y} <= fvsof(lines)
570+
571+
assert p1 == lines.origin
572+
573+
with handler({x: lambda: 3, y: lambda: 4}):
574+
evaluated_lines = evaluate(lines)
575+
576+
assert isinstance(evaluated_lines, Lines)
577+
assert evaluated_lines == Lines(
578+
origin=Point(3, 4),
579+
lines=[Line(Point(3, 4), Point(4, 5))],
580+
)

0 commit comments

Comments
 (0)