Skip to content

Commit bc23e4c

Browse files
authored
Add type inference module for embedding polymorphic types (#299)
* first commit * format * test passes * types * add docstring with doctests for unify * more test cases * case * docstring and test cases for infer_return_type * canonicalize * fixes * fix error type * add _nested_type for sequences * make _nested_type singledispatch extensible * term case * nit * docstrings and tests for canonicalize and nested_type * no notimmplentederror * variadic params * remove dead code paths in unify * simplify unify * union * tweak * cleanup * fix none case * more aggressive canonicalize applied only during unification * update canonicalize test * remove dead path * paramspec failure * literal and optional * paramspec canonicalize * simplify alias * single unify call * fast paths * union handling * freshen * type of freshen * reorder * fixpoint in substitute * tweak types * empty params * tweak substitute type * doctest * separate nested_type * handle defaults * update * remove test ids * add to sphinx * fix doctests * fix type checking * str * lint and format * freshen -> _freshen * factor out freetypevars and susbtitute * truncate names * doctest * doctest * add compositional tests * add tests for function types * ellipsis and paramspec * variadic tuple logic and union tests * fix union unify pattern * Use unification to implement `Operation.__type_rule__` (#300) * Use infer_return_type to implement Operation.__type_rule__ * move some logic out of infer_return_type * dont add defaults * update * remove duplicate default param logic * fix and format * lint * fix semiring test * address comment * update to 3.12 * union * try truncating?? * tweaks * split up expression in _freshen * no parallel build * revert script * break expressions u * script * try again with old script?? * rewrite * simplify * finer error type * pin jax * use typing internal api for substitute and freetypevars, and make canonicalize more robust * internal api * revert api * add test * fix doctest * remove unused var * address some comments * default attr * union soundness * fix union and add variadic assertion * reinstate support for bounds to make tests pass * nodefault * lint * forwardref
1 parent c424c3c commit bc23e4c

8 files changed

Lines changed: 2632 additions & 81 deletions

File tree

docs/source/effectful.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ Internals
9898
.. automodule:: effectful.internals.runtime
9999
:members:
100100
:undoc-members:
101+
102+
.. automodule:: effectful.internals.unification
103+
:members:
104+
:undoc-members:

docs/source/semi_ring.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,17 @@ def Let[S, T, A, B](
6262

6363

6464
@defop
65-
def Record[T](**kwargs: T) -> dict[str, T]:
65+
def Record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
6666
raise NotImplementedError
6767

6868

6969
@defop
70-
def Field[T](record: dict[str, T], key: str) -> T:
70+
def Field[T](record: collections.abc.Mapping[str, T], key: str) -> T:
7171
raise NotImplementedError
7272

7373

7474
@defop
75-
def Dict[K, V](*contents: Union[K, V]) -> SemiRingDict[K, V]:
75+
def Dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
7676
raise NotImplementedError
7777

7878

@@ -92,20 +92,14 @@ def add[T](x: T, y: T) -> T:
9292
ops.Field = Field
9393

9494

95-
def eager_dict[K, V](*contents: Tuple[K, V]) -> SemiRingDict[K, V]:
96-
if not any(isinstance(v, Term) for v in contents):
97-
if len(contents) % 2 != 0:
98-
raise ValueError("Dict requires an even number of arguments")
99-
100-
kv = []
101-
for i in range(0, len(contents), 2):
102-
kv.append((contents[i], contents[i + 1]))
103-
return SemiRingDict(kv)
95+
def eager_dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
96+
if not any(isinstance(v, Term) for kv in contents for v in kv):
97+
return SemiRingDict(list(contents))
10498
else:
10599
return fwd()
106100

107101

108-
def eager_record[T](**kwargs: T) -> dict[str, T]:
102+
def eager_record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
109103
if not any(isinstance(v, Term) for v in kwargs.values()):
110104
return dict(**kwargs)
111105
else:
@@ -215,7 +209,7 @@ def vertical_fusion[S, T](e1: T, x: Operation[[], T], e2: S) -> S:
215209
)
216210

217211
term: SemiRingDict[int, int] = Let(
218-
Sum(x(), k, v, Dict(k(), v() + 1)), y, Sum(y(), k, v, Dict(k(), v() + 1))
212+
Sum(x(), k, v, Dict((k(), v() + 1))), y, Sum(y(), k, v, Dict((k(), v() + 1)))
219213
)
220214

221215
print("Without optimization:", term)

0 commit comments

Comments
 (0)