Skip to content

Commit f38b185

Browse files
fixes
1 parent a4fa855 commit f38b185

4 files changed

Lines changed: 26 additions & 26 deletions

File tree

python/egglog/exp/array_api.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def _bool(
112112
rewrite(x == x).to(TRUE), # noqa: PLR0124
113113
rewrite(FALSE == TRUE).to(FALSE),
114114
rewrite(TRUE == FALSE).to(FALSE),
115-
rewrite(Boolean.if_(TRUE, bt, bf), subsume=False).to(bt()),
116-
rewrite(Boolean.if_(FALSE, bt, bf), subsume=False).to(bf()),
115+
rewrite(Boolean.if_(TRUE, bt, bf), subsume=True).to(bt()),
116+
rewrite(Boolean.if_(FALSE, bt, bf), subsume=True).to(bf()),
117117
rule(eq(Boolean(b)).to(Boolean(b1)), ne(b).to(b1)).then(panic("Different booleans cannot be equal")),
118118
]
119119

@@ -283,8 +283,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int, ot: Callable[[], Int], bt:
283283
yield rewrite(~Int(i)).to(Int(~i))
284284
yield rewrite(Int(i).__abs__()).to(Int(i.__abs__()))
285285

286-
yield rewrite(Int.if_(TRUE, ot, bt), subsume=False).to(ot())
287-
yield rewrite(Int.if_(FALSE, ot, bt), subsume=False).to(bt())
286+
yield rewrite(Int.if_(TRUE, ot, bt), subsume=True).to(ot())
287+
yield rewrite(Int.if_(FALSE, ot, bt), subsume=True).to(bt())
288288

289289
yield rewrite(o.__round__(OptionalInt.none)).to(o)
290290

@@ -841,8 +841,8 @@ def _tuple_int(
841841
yield rewrite(TupleInt(vs).length()).to(Int(vs.length()))
842842
yield rewrite(TupleInt(vs)[Int(k)]).to(vs[k])
843843

844-
yield rewrite(TupleInt.if_(TRUE, lt, lf), subsume=False).to(lt())
845-
yield rewrite(TupleInt.if_(FALSE, lt, lf), subsume=False).to(lf())
844+
yield rewrite(TupleInt.if_(TRUE, lt, lf), subsume=True).to(lt())
845+
yield rewrite(TupleInt.if_(FALSE, lt, lf), subsume=True).to(lf())
846846

847847
yield rewrite(TupleInt.fn(Int(k), idx_fn), subsume=True).to(TupleInt(k.range().map(lambda i: idx_fn(Int(i)))))
848848

@@ -964,8 +964,8 @@ def _tuple_tuple_int(
964964
TupleTupleInt(k.range().map(lambda i: idx_fn(Int(i))))
965965
)
966966

967-
yield rewrite(TupleTupleInt.if_(TRUE, lt, lf), subsume=False).to(lt())
968-
yield rewrite(TupleTupleInt.if_(FALSE, lt, lf), subsume=False).to(lf())
967+
yield rewrite(TupleTupleInt.if_(TRUE, lt, lf), subsume=True).to(lt())
968+
yield rewrite(TupleTupleInt.if_(FALSE, lt, lf), subsume=True).to(lf())
969969

970970

971971
class DType(Expr, ruleset=array_api_ruleset):
@@ -1172,8 +1172,8 @@ def _value(
11721172

11731173
yield rewrite(Value.from_float(Float.rational(BigRat(0, 1))) + v).to(v)
11741174

1175-
yield rewrite(Value.if_(TRUE, vt, v1t)).to(vt())
1176-
yield rewrite(Value.if_(FALSE, vt, v1t)).to(v1t())
1175+
yield rewrite(Value.if_(TRUE, vt, v1t), subsume=True).to(vt())
1176+
yield rewrite(Value.if_(FALSE, vt, v1t), subsume=True).to(v1t())
11771177

11781178
# ==
11791179
yield rewrite(Value.from_int(i) == Value.from_int(i1)).to(i == i1)
@@ -1337,8 +1337,8 @@ def _tuple_value(
13371337

13381338
yield rewrite(TupleValue.fn(Int(k), idx_fn), subsume=True).to(TupleValue(k.range().map(lambda i: idx_fn(Int(i)))))
13391339

1340-
yield rewrite(TupleValue.if_(TRUE, lt, lf), subsume=False).to(lt())
1341-
yield rewrite(TupleValue.if_(FALSE, lt, lf), subsume=False).to(lf())
1340+
yield rewrite(TupleValue.if_(TRUE, lt, lf), subsume=True).to(lt())
1341+
yield rewrite(TupleValue.if_(FALSE, lt, lf), subsume=True).to(lf())
13421342

13431343

13441344
@function
@@ -1857,7 +1857,7 @@ def _ndarray(
18571857
return [
18581858
rewrite(NDArray.fn(shape, dtype, idx_fn).shape, subsume=False).to(shape),
18591859
rewrite(NDArray.fn(shape, dtype, idx_fn).dtype, subsume=False).to(dtype),
1860-
rewrite(NDArray.fn(shape, dtype, idx_fn).index(idx), subsume=False).to(idx_fn(idx)),
1860+
rewrite(NDArray.fn(shape, dtype, idx_fn).index(idx), subsume=True).to(idx_fn(idx)),
18611861
rewrite(NDArray(rv).shape, subsume=False).to(rv.shape),
18621862
rewrite(NDArray(rv).index(TupleInt(vi)), subsume=False).to(rv[vi]),
18631863
# TODO: Special case scalar ops for now
@@ -1875,8 +1875,8 @@ def _ndarray(
18751875
# Transpose of transpose is the original array
18761876
rewrite(x.T.T, subsume=False).to(x),
18771877
# if_
1878-
rewrite(NDArray.if_(TRUE, xt, x1t), subsume=False).to(xt()),
1879-
rewrite(NDArray.if_(FALSE, xt, x1t), subsume=False).to(x1t()),
1878+
rewrite(NDArray.if_(TRUE, xt, x1t), subsume=True).to(xt()),
1879+
rewrite(NDArray.if_(FALSE, xt, x1t), subsume=True).to(x1t()),
18801880
# to RecursiveValue,
18811881
# only trigger if size smaller than 10 to avoid blowing up
18821882
rule(

python/egglog/exp/array_api_loopnest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _loopnest_api_ruleset(lna: LoopNestAPI, dim: Int, ti: TupleInt, idx_fn: Call
6565
# from_tuple
6666
yield rewrite(LoopNestAPI.from_tuple(TupleInt(())), subsume=True).to(OptionalLoopNestAPI.NONE)
6767
yield rewrite(LoopNestAPI.from_tuple(TupleInt(vs)), subsume=True).to(
68-
OptionalLoopNestAPI(LoopNestAPI(vs[0], LoopNestAPI.from_tuple(TupleInt(vs.remove(0))))),
68+
OptionalLoopNestAPI(LoopNestAPI(vs[vs.length() - 1], LoopNestAPI.from_tuple(TupleInt(vs.pop())))),
6969
vs.length() > 0,
7070
)
7171
# get_dims

python/egglog/exp/array_api_program_gen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _tuple_int_program(
9797

9898
yield rewrite(tuple_int_foldl_program(TupleInt(()), f, init)).to(init)
9999
yield rewrite(tuple_int_foldl_program(TupleInt(vi), f, init)).to(
100-
f(tuple_int_foldl_program(vi.remove(0), f, init), vi[vi.length() - 1]), vi.length() > 0
100+
f(tuple_int_foldl_program(vi.pop(), f, init), vi[vi.length() - 1]), vi.length() > 0
101101
)
102102

103103

@@ -196,7 +196,7 @@ def _tuple_value_program(
196196

197197
yield rewrite(tuple_value_foldl_program(TupleValue(()), f, init)).to(init)
198198
yield rewrite(tuple_value_foldl_program(TupleValue(vv), f, init)).to(
199-
f(tuple_value_foldl_program(vv.remove(0), f, init), vv[vv.length() - 1]), vv.length() > 0
199+
f(tuple_value_foldl_program(vv.pop(), f, init), vv[vv.length() - 1]), vv.length() > 0
200200
)
201201

202202

@@ -220,7 +220,7 @@ def _tuple_ndarray_program(
220220

221221
yield rewrite(tuple_ndarray_foldl_program(TupleNDArray(()), f, init)).to(init)
222222
yield rewrite(tuple_ndarray_foldl_program(TupleNDArray(vn), f, init)).to(
223-
f(tuple_ndarray_foldl_program(vn.remove(0), f, init), vn[vn.length() - 1]), vn.length() > 0
223+
f(tuple_ndarray_foldl_program(vn.pop(), f, init), vn[vn.length() - 1]), vn.length() > 0
224224
)
225225

226226

@@ -561,7 +561,7 @@ def vec_recursive_value_program(x: Vec[RecursiveValue]) -> Program: ...
561561

562562
@array_api_program_gen_ruleset.register
563563
def _vec_recursive_value_program(v: Value, vv: Vec[RecursiveValue]):
564-
yield rewrite(vec_recursive_value_program(Vec.empty())).to(Program(""))
564+
yield rewrite(vec_recursive_value_program(Vec[RecursiveValue].empty())).to(Program(""))
565565
yield rewrite(vec_recursive_value_program(vv)).to(
566566
recursive_value_program(vv[0]) + ", " + vec_recursive_value_program(vv.remove(0)),
567567
)

python/tests/__snapshots__/test_array_api/TestLoopNest.test_index_codegen[expr].py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
_Value_1 = NDArray.var("X").index(
2-
TupleInt.from_vec(Vec(Int(0), Int(0), Int.var("i"), Int.var("j")))
2+
TupleInt(Vec(Int(0), Int(0), Int.var("i"), Int.var("j")))
33
)
44
_Value_2 = NDArray.var("X").index(
5-
TupleInt.from_vec(Vec(Int(0), Int(1), Int.var("i"), Int.var("j")))
5+
TupleInt(Vec(Int(0), Int(1), Int.var("i"), Int.var("j")))
66
)
77
_Value_3 = NDArray.var("X").index(
8-
TupleInt.from_vec(Vec(Int(1), Int(0), Int.var("i"), Int.var("j")))
8+
TupleInt(Vec(Int(1), Int(0), Int.var("i"), Int.var("j")))
99
)
1010
_Value_4 = NDArray.var("X").index(
11-
TupleInt.from_vec(Vec(Int(1), Int(1), Int.var("i"), Int.var("j")))
11+
TupleInt(Vec(Int(1), Int(1), Int.var("i"), Int.var("j")))
1212
)
1313
_Value_5 = NDArray.var("X").index(
14-
TupleInt.from_vec(Vec(Int(2), Int(0), Int.var("i"), Int.var("j")))
14+
TupleInt(Vec(Int(2), Int(0), Int.var("i"), Int.var("j")))
1515
)
1616
_Value_6 = NDArray.var("X").index(
17-
TupleInt.from_vec(Vec(Int(2), Int(1), Int.var("i"), Int.var("j")))
17+
TupleInt(Vec(Int(2), Int(1), Int.var("i"), Int.var("j")))
1818
)
1919
(
2020
(_Value_1.conj() * _Value_1).real()

0 commit comments

Comments
 (0)