Skip to content

Commit e21d7d4

Browse files
tmp
1 parent c282b59 commit e21d7d4

13 files changed

Lines changed: 72 additions & 339 deletions

python/egglog/builtins.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,9 @@ def remove(self, index: i64Like) -> Vec[T]: ...
10211021
@method(egg_fn="vec-set")
10221022
def set(self, index: i64Like, value: T) -> Vec[T]: ...
10231023

1024+
@method(egg_fn="vec-union")
1025+
def __or__(self, other: Vec[T]) -> Vec[T]: ...
1026+
10241027

10251028
for sequence_type in (list, tuple):
10261029
converter(

python/egglog/exp/array_api.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,26 @@
11
"""
2-
2+
Experimental Array API support.
33
44
## Lists
55
66
Lists have two main constructors:
77
88
- `List(length, idx_fn)`
9-
- `List.EMPTY` / `initial.append(last)`
9+
- `List.from_vec(vs)`
1010
11-
This is so that they can be defined either with a known fixed integer length (the cons list type) or a symbolic
11+
This is so that they can be defined either with a known fixed integer length or a symbolic
1212
length that could not be resolved to an integer.
1313
14-
There are rewrites to convert between these constructors in both directions. The only limitation however is that
15-
`length` has to a real i64 in order to be converted to a cons list.
16-
17-
When you are writing a function that uses ints, feel free to the `__getitem__` or `length()` methods or match
18-
directly on `List()` constructor. If you can write your function using that interface please do. But for some other
19-
methods whether the resulting length/index function is dependent on the rest of it, you can only define it with a known
20-
length, so you can then use the const list constructors.
21-
22-
We also support creating lists from vectors. These can be converted one to one to the snoc list representation.
23-
24-
It is troublesome to have to redefine lists for every type. It would be nice to have generic types, but they are not implemented yet.
25-
26-
We are guaranteed that all lists with known lengths will be represented as cons/empty. To safely use lists, use
27-
the `.length` and `.__getitem__` methods, unless you want to depend on it having known length, in which
28-
case you can match directly on the cons list.
29-
30-
To be a list, you must implement two methods:
14+
Both constructors must implement two methods:
3115
3216
* `l.length() -> Int`
3317
* `l.__getitem__(i: Int) -> T`
3418
35-
There are three main types of constructors for lists which all implement these methods:
36-
37-
* Functional `List(length, idx_fn)`
38-
* cons (well reversed cons) lists `List.EMPTY` and `l.append(x)`
39-
* Vectors `List.from_vec(vec)`
40-
41-
Also all lists constructors must be converted to the functional representation, so that we can match on it
42-
and convert lists with known lengths into cons lists and into vectors.
43-
44-
This is necessary so that known length lists are properly materialized during extraction.
45-
46-
Q: Why are they implemented as SNOC lists instead of CONS lists?
47-
A: So that when converting from functional to lists we can use the same index function by starting at the end and folding
48-
that way recursively.
19+
Lists with a known length will be subsumed into the vector representation.
4920
21+
Lists that have vecs that are equal will have the elements unified.
5022
23+
Methods that transform lists should also subsume, so that the vector version will be preferred.
5124
"""
5225

5326
# mypy: disable-error-code="empty-body"
@@ -89,6 +62,8 @@
8962

9063

9164
class Boolean(Expr, ruleset=array_api_ruleset):
65+
NEVER: ClassVar[Boolean]
66+
9267
def __init__(self, value: BoolLike) -> None: ...
9368

9469
@method(preserve=True)
@@ -479,6 +454,7 @@ def __len__(self) -> int:
479454
def __iter__(self) -> Iterator[Int]:
480455
return iter(self.eval())
481456

457+
@method(merge=Vec.__or__) # type: ignore[prop-decorator]
482458
@property
483459
def to_vec(self) -> Vec[Int]: ...
484460

@@ -646,6 +622,7 @@ def __len__(self) -> int:
646622
def __iter__(self) -> Iterator[TupleInt]:
647623
return iter(self.eval())
648624

625+
@method(merge=Vec.__or__) # type: ignore[prop-decorator]
649626
@property
650627
def to_vec(self) -> Vec[TupleInt]: ...
651628

@@ -836,7 +813,8 @@ def bool(cls, b: BooleanLike) -> Value: ...
836813

837814
def isfinite(self) -> Boolean: ...
838815

839-
def __lt__(self, other: ValueLike) -> Boolean: ...
816+
# TODO: Fix
817+
def __lt__(self, other: ValueLike) -> Value: ...
840818
def __le__(self, other: ValueLike) -> Boolean: ...
841819
def __gt__(self, other: ValueLike) -> Boolean: ...
842820
def __ge__(self, other: ValueLike) -> Boolean: ...
@@ -941,8 +919,8 @@ def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, v2: Value, i1: Int
941919
yield rewrite(Value.int(i) > Value.int(i1)).to(i > i1)
942920
yield rewrite(Value.float(f) > Value.float(f1)).to(f > f1)
943921
# <
944-
yield rewrite(Value.int(i) < Value.int(i1)).to(i < i1)
945-
yield rewrite(Value.float(f) < Value.float(f1)).to(f < f1)
922+
yield rewrite(Value.int(i) < Value.int(i1)).to(Value.bool(i < i1))
923+
yield rewrite(Value.float(f) < Value.float(f1)).to(Value.bool(f < f1))
946924

947925
# /
948926
yield rewrite(Value.float(f) / Value.float(f1)).to(Value.float(f / f1))
@@ -1573,6 +1551,7 @@ def __len__(self) -> int:
15731551
def __iter__(self) -> Iterator[NDArray]:
15741552
return iter(self.eval())
15751553

1554+
@method(merge=Vec.__or__) # type: ignore[prop-decorator]
15761555
@property
15771556
def to_vec(self) -> Vec[NDArray]: ...
15781557

@@ -1933,9 +1912,9 @@ def vector_norm(x: NDArrayLike) -> NDArray:
19331912
"""
19341913
https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.vector_norm.html
19351914
TODO: support axis
1936-
>>> x = NDArray.vector([1, 2, 3, 4, 5, 6, 7, 8, 9])
1937-
>>> vector_norm(x).eval_numpy("float64")
1938-
array(16.88194302)
1915+
# >>> x = NDArray.vector([1, 2, 3, 4, 5, 6, 7, 8, 9])
1916+
# >>> vector_norm(x).eval_numpy("float64")
1917+
# array(16.88194302)
19391918
"""
19401919
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html#numpy.linalg.norm
19411920
# sum(abs(x)**ord)**(1./ord) where ord=2
@@ -2068,7 +2047,7 @@ def _interval_analaysis(
20682047
NDArray.scalar(Value.bool(possible_values(x.index(ALL_INDICES).to_truthy_value).contains(Value.bool(TRUE))))
20692048
),
20702049
# Indexing x < y is the same as broadcasting the index and then indexing both and then comparing
2071-
rewrite((x < y).index(idx)).to(Value.bool(x_value < y_value)),
2050+
rewrite((x < y).index(idx)).to(x_value < y_value),
20722051
# Same for x / y
20732052
rewrite((x / y).index(idx)).to(x_value / y_value),
20742053
# Indexing a scalar is the same as the scalar
@@ -2102,7 +2081,7 @@ def _interval_analaysis(
21022081
# Define v < 0 to be false, if greater_zero(v)
21032082
rule(
21042083
greater_zero(v),
2105-
eq(v1).to(Value.bool(v < Value.int(Int(0)))),
2084+
eq(v1).to(v < Value.int(Int(0))),
21062085
).then(
21072086
union(v1).with_(Value.bool(FALSE)),
21082087
),

python/egglog/exp/array_api_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def jit(
3131
fn_program = EvalProgram(program, {"np": np})
3232
egraph.register(fn_program)
3333
egraph.run(array_api_program_gen_schedule)
34-
return cast("X", fn_program.as_py_object.value)
34+
return cast("X", egraph.extract(fn_program.as_py_object).value)
3535

3636

3737
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:

python/egglog/exp/array_api_program_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Valu
144144
yield rewrite(value_program(Value.float(f))).to(float_program(f))
145145
# Could add .item() but we usually dont need it.
146146
yield rewrite(value_program(x.to_value())).to(ndarray_program(x))
147-
yield rewrite(bool_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")")
147+
yield rewrite(value_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")")
148148
yield rewrite(value_program(v1 / v2)).to(Program("(") + value_program(v1) + " / " + value_program(v2) + ")")
149149
yield rewrite(value_program(v1 + v2)).to(Program("(") + value_program(v1) + " + " + value_program(v2) + ")")
150150
yield rewrite(value_program(v1 * v2)).to(Program("(") + value_program(v1) + " * " + value_program(v2) + ")")

python/egglog/pretty.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
]
2727
MAX_LINE_LENGTH = 88
2828
LINE_DIFFERENCE = 10
29-
BLACK_MODE = black.Mode(line_length=88)
29+
BLACK_MODE = black.Mode(line_length=MAX_LINE_LENGTH)
3030

3131
# Use this special character in place of the args, so that if the args are inlined
3232
# in the viz, they will replace it
@@ -97,7 +97,10 @@ def pretty_decl(
9797
expr = f"{wrapping_fn}({expr})"
9898
program = "\n".join([*pretty.statements, expr])
9999
# First unparse AST to get consistent formatting, then use black to format it nicely
100-
ast_tree = ast.parse(program, mode="exec")
100+
try:
101+
ast_tree = ast.parse(program, mode="exec")
102+
except SyntaxError:
103+
return program
101104
program = ast.unparse(ast_tree)
102105
try:
103106
# TODO: Try replacing with ruff for speed
Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1-
_Value_1 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), Int.var("i"), Int.var("j"))))
2-
_Value_2 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), Int.var("i"), Int.var("j"))))
3-
_Value_3 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), Int.var("i"), Int.var("j"))))
4-
_Value_4 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), Int.var("i"), Int.var("j"))))
5-
_Value_5 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), Int.var("i"), Int.var("j"))))
6-
_Value_6 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), Int.var("i"), Int.var("j"))))
1+
_Value_1 = NDArray.var("X").index(
2+
TupleInt.from_vec(Vec(Int(0), Int(0), Int.var("i"), Int.var("j")))
3+
)
4+
_Value_2 = NDArray.var("X").index(
5+
TupleInt.from_vec(Vec(Int(0), Int(1), Int.var("i"), Int.var("j")))
6+
)
7+
_Value_3 = NDArray.var("X").index(
8+
TupleInt.from_vec(Vec(Int(1), Int(0), Int.var("i"), Int.var("j")))
9+
)
10+
_Value_4 = NDArray.var("X").index(
11+
TupleInt.from_vec(Vec(Int(1), Int(1), Int.var("i"), Int.var("j")))
12+
)
13+
_Value_5 = NDArray.var("X").index(
14+
TupleInt.from_vec(Vec(Int(2), Int(0), Int.var("i"), Int.var("j")))
15+
)
16+
_Value_6 = NDArray.var("X").index(
17+
TupleInt.from_vec(Vec(Int(2), Int(1), Int.var("i"), Int.var("j")))
18+
)
719
(
8-
(
9-
((((_Value_1.conj() * _Value_1).real() + (_Value_2.conj() * _Value_2).real()) + (_Value_3.conj() * _Value_3).real()) + (_Value_4.conj() * _Value_4).real())
10-
+ (_Value_5.conj() * _Value_5).real()
11-
)
20+
(_Value_1.conj() * _Value_1).real()
21+
+ (_Value_2.conj() * _Value_2).real()
22+
+ (_Value_3.conj() * _Value_3).real()
23+
+ (_Value_4.conj() * _Value_4).real()
24+
+ (_Value_5.conj() * _Value_5).real()
1225
+ (_Value_6.conj() * _Value_6).real()
1326
).sqrt()

python/tests/__snapshots__/test_array_api/test_jit[lda][code].py

Lines changed: 0 additions & 64 deletions
This file was deleted.

python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py

Lines changed: 0 additions & 95 deletions
This file was deleted.

0 commit comments

Comments
 (0)