Skip to content

Commit c089e16

Browse files
Done?
1 parent f38b185 commit c089e16

13 files changed

Lines changed: 562 additions & 46 deletions

python/egglog/egraph_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C9
397397
bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types])
398398
)
399399
case ConstantDecl(tp, _):
400-
# Use constructor decleration instead of constant b/c constants cannot be extracted
400+
# Use constructor declaration instead of constant b/c constants cannot be extracted
401401
# https://github.com/egraphs-good/egglog/issues/334
402402
is_function = self.__egg_decls__._classes[tp.ident].builtin
403403
schema = bindings.Schema([], self.type_ref_to_egg(tp))

python/egglog/exp/array_api.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def _tuple_int(
836836
yield rule(eq(ti).to(TupleInt(vs)), eq(ti).to(TupleInt(vs2)), vs != vs2).then(vs | vs2)
837837

838838
yield rewrite(TupleInt.fn(i2, idx_fn).length(), subsume=False).to(i2)
839-
yield rewrite(TupleInt.fn(i2, idx_fn)[i], subsume=False).to(idx_fn(check_index(i2, i)))
839+
yield rewrite(TupleInt.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
840840

841841
yield rewrite(TupleInt(vs).length()).to(Int(vs.length()))
842842
yield rewrite(TupleInt(vs)[Int(k)]).to(vs[k])
@@ -955,7 +955,7 @@ def _tuple_tuple_int(
955955
yield rule(eq(ti).to(TupleTupleInt(vs)), eq(ti).to(TupleTupleInt(vs2)), vs != vs2).then(vs | vs2)
956956

957957
yield rewrite(TupleTupleInt.fn(i2, idx_fn).length(), subsume=False).to(i2)
958-
yield rewrite(TupleTupleInt.fn(i2, idx_fn)[i], subsume=False).to(idx_fn(check_index(i2, i)))
958+
yield rewrite(TupleTupleInt.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
959959

960960
yield rewrite(TupleTupleInt(vs).length(), subsume=False).to(Int(vs.length()))
961961
yield rewrite(TupleTupleInt(vs)[Int(k)], subsume=False).to(vs[k])
@@ -1330,7 +1330,7 @@ def _tuple_value(
13301330
yield rule(eq(ti).to(TupleValue(vs)), eq(ti).to(TupleValue(vs2)), vs != vs2).then(vs | vs2)
13311331

13321332
yield rewrite(TupleValue.fn(i2, idx_fn).length(), subsume=False).to(i2)
1333-
yield rewrite(TupleValue.fn(i2, idx_fn)[i], subsume=False).to(idx_fn(check_index(i2, i)))
1333+
yield rewrite(TupleValue.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
13341334

13351335
yield rewrite(TupleValue(vs).length(), subsume=False).to(Int(vs.length()))
13361336
yield rewrite(TupleValue(vs)[Int(k)], subsume=False).to(vs[k], k >= 0, k < vs.length())
@@ -1961,7 +1961,7 @@ def _tuple_ndarray(
19611961
):
19621962
yield rule(eq(ti).to(TupleNDArray(vs)), eq(ti).to(TupleNDArray(vs2)), vs != vs2).then(vs | vs2)
19631963
yield rewrite(TupleNDArray.fn(i2, idx_fn).length(), subsume=False).to(i2)
1964-
yield rewrite(TupleNDArray.fn(i2, idx_fn)[i], subsume=False).to(idx_fn(check_index(i2, i)))
1964+
yield rewrite(TupleNDArray.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
19651965

19661966
yield rewrite(TupleNDArray(vs).length(), subsume=False).to(Int(vs.length()))
19671967
yield rewrite(TupleNDArray(vs)[Int(k)], subsume=False).to(vs[k], k >= 0, k < vs.length())
@@ -2148,26 +2148,33 @@ def _astype(x: NDArray, dtype: DType, i: i64):
21482148
]
21492149

21502150

2151-
@function
2151+
@function(unextractable=True, ruleset=array_api_ruleset)
21522152
def unique_counts(x: NDArray) -> TupleNDArray:
21532153
"""
21542154
Returns the unique elements of an input array x and the corresponding counts for each unique element in x.
21552155
21562156
21572157
https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
21582158
"""
2159+
return TupleNDArray((unique_counts_elements(x), unique_counts_counts(x)))
2160+
2161+
2162+
@function
2163+
def unique_counts_elements(x: NDArray) -> NDArray: ...
2164+
2165+
2166+
@function
2167+
def unique_counts_counts(x: NDArray) -> NDArray: ...
21592168

21602169

21612170
@array_api_ruleset.register
21622171
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType):
21632172
return [
2164-
# rewrite(unique_counts(x).length()).to(Int(2)),
2165-
rewrite(unique_counts(x)).to(TupleNDArray.fn(2, unique_counts(x).__getitem__)),
21662173
# Sum of all unique counts is the size of the array
2167-
rewrite(sum(unique_counts(x)[Int(1)])).to(NDArray(Value.from_int(x.size))),
2174+
rewrite(sum(unique_counts_counts(x))).to(NDArray(Value.from_int(x.size))),
21682175
# Same but with astype in the middle
21692176
# TODO: Replace
2170-
rewrite(sum(astype(unique_counts(x)[Int(1)], dtype))).to(astype(NDArray(Value.from_int(x.size)), dtype)),
2177+
rewrite(sum(astype(unique_counts_counts(x), dtype))).to(astype(NDArray(Value.from_int(x.size)), dtype)),
21712178
]
21722179

21732180

@@ -2194,22 +2201,25 @@ def _abs(f: Float):
21942201
]
21952202

21962203

2197-
@function
2204+
@function(ruleset=array_api_ruleset, unextractable=True)
21982205
def unique_inverse(x: NDArray) -> TupleNDArray:
21992206
"""
22002207
Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.
22012208
22022209
https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
22032210
"""
2211+
return TupleNDArray((unique_values(x), unique_inverse_inverse_indices(x)))
2212+
2213+
2214+
@function
2215+
def unique_inverse_inverse_indices(x: NDArray) -> NDArray: ...
22042216

22052217

22062218
@array_api_ruleset.register
22072219
def _unique_inverse(x: NDArray, i: Int):
22082220
return [
2209-
# rewrite(unique_inverse(x).length()).to(Int(2)),
2210-
rewrite(unique_inverse(x)).to(TupleNDArray.fn(2, unique_inverse(x).__getitem__)),
22112221
# Shape of unique_inverse first element is same as shape of unique_values
2212-
rewrite(unique_inverse(x)[Int(0)]).to(unique_values(x)),
2222+
rewrite(unique_values(x)[Int(0)]).to(unique_values(x)),
22132223
]
22142224

22152225

@@ -2319,19 +2329,19 @@ def cross(a: NDArrayLike, b: NDArrayLike) -> NDArray:
23192329
linalg = sys.modules[__name__]
23202330

23212331

2322-
@function
2323-
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
2332+
def svd(x: NDArray, full_matrices: Boolean = TRUE) -> tuple[NDArray, NDArray, NDArray]:
23242333
"""
23252334
https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
23262335
"""
2336+
res = svd_(x, full_matrices)
2337+
return (res[0], res[1], res[2])
23272338

23282339

2329-
@array_api_ruleset.register
2330-
def _linalg(x: NDArray, full_matrices: Boolean):
2331-
return [
2332-
# rewrite(svd(x, full_matrices).length()).to(Int(3)),
2333-
rewrite(svd(x, full_matrices)).to(TupleNDArray.fn(3, svd(x, full_matrices).__getitem__)),
2334-
]
2340+
@function
2341+
def svd_(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray:
2342+
"""
2343+
https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
2344+
"""
23352345

23362346

23372347
@function(ruleset=array_api_ruleset, unextractable=True)
@@ -2427,7 +2437,7 @@ def _interval_analaysis(
24272437
# rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray(Value.bool(FALSE)))),
24282438
# Indexing into unique counts counts are all positive
24292439
rule(
2430-
eq(v).to(unique_counts(x)[Int(1)].index(idx)),
2440+
eq(v).to(unique_counts_counts(x).index(idx)),
24312441
).then(greater_zero(v)),
24322442
# Min value preserved over astype
24332443
rule(
@@ -2710,15 +2720,20 @@ def try_evaling(expr: ExprWithValue[T_co]) -> T_co:
27102720
egraph = _get_current_egraph()
27112721
egraph.register(expr) # type: ignore[arg-type]
27122722
egraph.run(array_api_schedule)
2713-
# run on another e-graph to get around bug
2714-
# https://github.com/egraphs-good/egglog/issues/801
2715-
# return egraph.extract(expr).value # type: ignore[call-overload]
27162723
# egraph.display(n_inline_leaves=2, split_primitive_outputs=True, split_functions=[Int])
27172724
extracted_expr = egraph.extract(expr) # type: ignore[call-overload]
2725+
with contextlib.suppress(ExprValueError):
2726+
extracted_expr.value
2727+
# run on another e-graph to get around bug
2728+
# https://github.com/egraphs-good/egglog/issues/801
2729+
27182730
new_egraph = EGraph()
27192731
new_egraph.register(extracted_expr)
27202732
new_egraph.run(array_api_schedule)
27212733
return new_egraph.extract(extracted_expr).value
2734+
# except EggSmolError as e:
2735+
2736+
# raise e
27222737

27232738
# try:
27242739
# return egraph.extract(prim_expr).value # type: ignore[attr-defined]

python/egglog/exp/array_api_jit.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from egglog import EGraph, greedy_dag_cost_model
7+
from egglog import EGraph, bindings, greedy_dag_cost_model
88
from egglog.exp.array_api import NDArray, set_array_api_egraph
99
from egglog.exp.array_api_numba import array_api_numba_schedule
1010
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
@@ -24,14 +24,21 @@ def jit(
2424
Jit compiles a function
2525
"""
2626
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
27+
egraph = EGraph()
2728
if handle_expr:
2829
handle_expr(res)
2930
if handle_optimized_expr:
3031
handle_optimized_expr(res_optimized)
3132
fn_program = EvalProgram(program, {"np": np})
3233
egraph.register(fn_program)
34+
3335
egraph.run(array_api_program_gen_schedule)
34-
return cast("X", egraph.extract(fn_program.as_py_object).value)
36+
37+
try:
38+
return cast("X", egraph.extract(fn_program.as_py_object).value)
39+
except bindings.EggSmolError as e:
40+
e.add_note(f"Failed to get py object from {egraph.extract(fn_program)}")
41+
raise
3542

3643

3744
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:

python/egglog/exp/array_api_numba.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def count_values(x: NDArrayLike, values: TupleValueLike) -> TupleValue:
4949
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
5050
return [
5151
# The unique counts are the count of all the unique values
52-
rewrite(unique_counts(x)[1], subsume=True).to(
52+
rewrite(unique_counts_counts(x), subsume=True).to(
5353
NDArray.from_tuple_value(count_values(x, unique_values(x).to_tuple_values()))
5454
),
5555
]
@@ -60,5 +60,5 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
6060
def _unique_inverse(x: NDArray, i: Int):
6161
return [
6262
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
63-
rewrite(unique_inverse(x)[1] == i, subsume=True).to(x == unique_values(x).index((i,))),
63+
rewrite(unique_inverse_inverse_indices(x) == i, subsume=True).to(x == unique_values(x).index((i,))),
6464
]

python/egglog/exp/array_api_program_gen.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def int_program(x: Int) -> Program: ...
3535

3636

3737
@array_api_program_gen_ruleset.register
38-
def _int_program(i64_: i64, i: Int, j: Int, s: String):
38+
def _int_program(i64_: i64, i: Int, j: Int, s: String, b: Boolean, ti: Callable[[], Int], ti1: Callable[[], Int]):
3939
yield rewrite(int_program(Int.var(s))).to(Program(s, True))
4040
yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string()))
4141
yield rewrite(int_program(~i)).to(Program("~") + int_program(i))
@@ -57,6 +57,15 @@ def _int_program(i64_: i64, i: Int, j: Int, s: String):
5757
yield rewrite(int_program(i >> j)).to(Program("(") + int_program(i) + " >> " + int_program(j) + ")")
5858
yield rewrite(int_program(i // j)).to(Program("(") + int_program(i) + " // " + int_program(j) + ")")
5959

60+
assigned = int_program(j).assign()
61+
yield rewrite(int_program(check_index(i, j)), subsume=True).to(
62+
assigned.statement(Program("assert ") + assigned + " < " + int_program(i))
63+
)
64+
65+
yield rewrite(int_program(Int.if_(b, ti, ti1))).to(
66+
int_program(ti()) + " if " + bool_program(b) + " else " + int_program(ti1())
67+
)
68+
6069

6170
@function
6271
def program_if(b: BooleanLike, t: Callable[[], Program], f: Callable[[], Program]) -> Program: ...
@@ -168,7 +177,10 @@ def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Valu
168177
yield rewrite(value_program(v1 * v2)).to(Program("(") + value_program(v1) + " * " + value_program(v2) + ")")
169178
yield rewrite(bool_program(v1.to_bool)).to(value_program(v1))
170179
yield rewrite(int_program(v1.to_int)).to(value_program(v1))
171-
yield rewrite(value_program(xs.index(ti))).to((ndarray_program(xs) + "[" + tuple_int_program(ti) + "]").assign())
180+
yield rewrite(value_program(xs.index(ti))).to(
181+
(ndarray_program(xs) + "[" + tuple_int_program(ti) + "]").assign(), ne(ti).to(TupleInt(()))
182+
)
183+
yield rewrite(value_program(xs.index(TupleInt(())))).to(ndarray_program(xs))
172184
yield rewrite(value_program(v1.sqrt())).to(Program("np.sqrt(") + value_program(v1) + ")")
173185
yield rewrite(value_program(v1.real())).to(Program("np.real(") + value_program(v1) + ")")
174186
yield rewrite(value_program(v1.conj())).to(Program("np.conj(") + value_program(v1) + ")")
@@ -516,8 +528,8 @@ def bin_op(res: NDArray, op: str) -> Command:
516528
optional_int_or_tuple_ != OptionalIntOrTuple.none,
517529
)
518530
# svd
519-
yield rewrite(tuple_ndarray_program(svd(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign())
520-
yield rewrite(tuple_ndarray_program(svd(x, FALSE))).to(
531+
yield rewrite(tuple_ndarray_program(svd_(x))).to((Program("np.linalg.svd(") + ndarray_program(x) + ")").assign())
532+
yield rewrite(tuple_ndarray_program(svd_(x, FALSE))).to(
521533
(Program("np.linalg.svd(") + ndarray_program(x) + ", full_matrices=False)").assign()
522534
)
523535
# sqrt
@@ -564,4 +576,5 @@ def _vec_recursive_value_program(v: Value, vv: Vec[RecursiveValue]):
564576
yield rewrite(vec_recursive_value_program(Vec[RecursiveValue].empty())).to(Program(""))
565577
yield rewrite(vec_recursive_value_program(vv)).to(
566578
recursive_value_program(vv[0]) + ", " + vec_recursive_value_program(vv.remove(0)),
579+
vv.length() > 0,
567580
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
def __fn(X, y):
2+
assert X.dtype == np.dtype(np.float64)
3+
assert X.shape == (150, 4, )
4+
assert np.all(np.isfinite(X))
5+
assert y.dtype == np.dtype(np.int64)
6+
assert y.shape == (150, )
7+
assert set(np.unique(y)) == set((0, 1, 2, ))
8+
_0 = y == np.array(0)
9+
_1 = np.sum(_0)
10+
_2 = y == np.array(1)
11+
_3 = np.sum(_2)
12+
_4 = y == np.array(2)
13+
_5 = np.sum(_4)
14+
_6 = np.array((_1, _3, _5, )).astype(np.dtype(np.float64))
15+
_7 = _6 / np.array(150)
16+
_8 = np.zeros((3, 4, ), dtype=np.dtype(np.float64))
17+
_9 = np.sum(X[_0], axis=0)
18+
_10 = _9 / np.array(X[_0].shape[0])
19+
_8[0, :,] = _10
20+
_11 = np.sum(X[_2], axis=0)
21+
_12 = _11 / np.array(X[_2].shape[0])
22+
_8[1, :,] = _12
23+
_13 = np.sum(X[_4], axis=0)
24+
_14 = _13 / np.array(X[_4].shape[0])
25+
_8[2, :,] = _14
26+
_15 = _7 @ _8
27+
_16 = X - _15
28+
_17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64)))
29+
_18 = X[_0] - _8[0, :,]
30+
_19 = X[_2] - _8[1, :,]
31+
_20 = X[_4] - _8[2, :,]
32+
_21 = np.concatenate((_18, _19, _20, ), axis=0)
33+
_22 = np.sum(_21, axis=0)
34+
_23 = _22 / np.array(_21.shape[0])
35+
_24 = np.expand_dims(_23, 0)
36+
_25 = _21 - _24
37+
_26 = np.square(_25)
38+
_27 = np.sum(_26, axis=0)
39+
_28 = _27 / np.array(_26.shape[0])
40+
_29 = np.sqrt(_28)
41+
_30 = _29 == np.array(0)
42+
_29[_30] = np.array((150 / 150))
43+
_31 = _21 / _29
44+
_32 = _17 * _31
45+
_33 = np.linalg.svd(_32, full_matrices=False)
46+
_34 = _33[1] > np.array(0.0001)
47+
_35 = _34.astype(np.dtype(np.int32))
48+
_36 = np.sum(_35)
49+
_37 = _33[2][:_36, :,] / _29
50+
_38 = _37.T / _33[1][:_36]
51+
_39 = np.array(150) * _7
52+
_40 = _39 * np.array(float(1 / 2))
53+
_41 = np.sqrt(_40)
54+
_42 = _8 - _15
55+
_43 = _41 * _42.T
56+
_44 = _43.T @ _38
57+
_45 = np.linalg.svd(_44, full_matrices=False)
58+
_46 = np.array(0.0001) * _45[1][0]
59+
_47 = _45[1] > _46
60+
_48 = _47.astype(np.dtype(np.int32))
61+
_49 = np.sum(_48)
62+
_50 = _38 @ _45[2].T[:, :_49,]
63+
_51 = _16 @ _50
64+
return _51[:, :2,]

0 commit comments

Comments
 (0)