@@ -112,8 +112,8 @@ def _bool(
112112 rewrite (x == x ).to (TRUE ), # noqa: PLR0124
113113 rewrite (FALSE == TRUE ).to (FALSE ),
114114 rewrite (TRUE == FALSE ).to (FALSE ),
115- rewrite (Boolean .if_ (TRUE , bt , bf ), subsume = False ).to (bt ()),
116- rewrite (Boolean .if_ (FALSE , bt , bf ), subsume = False ).to (bf ()),
115+ rewrite (Boolean .if_ (TRUE , bt , bf ), subsume = True ).to (bt ()),
116+ rewrite (Boolean .if_ (FALSE , bt , bf ), subsume = True ).to (bf ()),
117117 rule (eq (Boolean (b )).to (Boolean (b1 )), ne (b ).to (b1 )).then (panic ("Different booleans cannot be equal" )),
118118 ]
119119
@@ -283,8 +283,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int, ot: Callable[[], Int], bt:
283283 yield rewrite (~ Int (i )).to (Int (~ i ))
284284 yield rewrite (Int (i ).__abs__ ()).to (Int (i .__abs__ ()))
285285
286- yield rewrite (Int .if_ (TRUE , ot , bt ), subsume = False ).to (ot ())
287- yield rewrite (Int .if_ (FALSE , ot , bt ), subsume = False ).to (bt ())
286+ yield rewrite (Int .if_ (TRUE , ot , bt ), subsume = True ).to (ot ())
287+ yield rewrite (Int .if_ (FALSE , ot , bt ), subsume = True ).to (bt ())
288288
289289 yield rewrite (o .__round__ (OptionalInt .none )).to (o )
290290
@@ -841,8 +841,8 @@ def _tuple_int(
841841 yield rewrite (TupleInt (vs ).length ()).to (Int (vs .length ()))
842842 yield rewrite (TupleInt (vs )[Int (k )]).to (vs [k ])
843843
844- yield rewrite (TupleInt .if_ (TRUE , lt , lf ), subsume = False ).to (lt ())
845- yield rewrite (TupleInt .if_ (FALSE , lt , lf ), subsume = False ).to (lf ())
844+ yield rewrite (TupleInt .if_ (TRUE , lt , lf ), subsume = True ).to (lt ())
845+ yield rewrite (TupleInt .if_ (FALSE , lt , lf ), subsume = True ).to (lf ())
846846
847847 yield rewrite (TupleInt .fn (Int (k ), idx_fn ), subsume = True ).to (TupleInt (k .range ().map (lambda i : idx_fn (Int (i )))))
848848
@@ -964,8 +964,8 @@ def _tuple_tuple_int(
964964 TupleTupleInt (k .range ().map (lambda i : idx_fn (Int (i ))))
965965 )
966966
967- yield rewrite (TupleTupleInt .if_ (TRUE , lt , lf ), subsume = False ).to (lt ())
968- yield rewrite (TupleTupleInt .if_ (FALSE , lt , lf ), subsume = False ).to (lf ())
967+ yield rewrite (TupleTupleInt .if_ (TRUE , lt , lf ), subsume = True ).to (lt ())
968+ yield rewrite (TupleTupleInt .if_ (FALSE , lt , lf ), subsume = True ).to (lf ())
969969
970970
971971class DType (Expr , ruleset = array_api_ruleset ):
@@ -1172,8 +1172,8 @@ def _value(
11721172
11731173 yield rewrite (Value .from_float (Float .rational (BigRat (0 , 1 ))) + v ).to (v )
11741174
1175- yield rewrite (Value .if_ (TRUE , vt , v1t )).to (vt ())
1176- yield rewrite (Value .if_ (FALSE , vt , v1t )).to (v1t ())
1175+ yield rewrite (Value .if_ (TRUE , vt , v1t ), subsume = True ).to (vt ())
1176+ yield rewrite (Value .if_ (FALSE , vt , v1t ), subsume = True ).to (v1t ())
11771177
11781178 # ==
11791179 yield rewrite (Value .from_int (i ) == Value .from_int (i1 )).to (i == i1 )
@@ -1337,8 +1337,8 @@ def _tuple_value(
13371337
13381338 yield rewrite (TupleValue .fn (Int (k ), idx_fn ), subsume = True ).to (TupleValue (k .range ().map (lambda i : idx_fn (Int (i )))))
13391339
1340- yield rewrite (TupleValue .if_ (TRUE , lt , lf ), subsume = False ).to (lt ())
1341- yield rewrite (TupleValue .if_ (FALSE , lt , lf ), subsume = False ).to (lf ())
1340+ yield rewrite (TupleValue .if_ (TRUE , lt , lf ), subsume = True ).to (lt ())
1341+ yield rewrite (TupleValue .if_ (FALSE , lt , lf ), subsume = True ).to (lf ())
13421342
13431343
13441344@function
@@ -1857,7 +1857,7 @@ def _ndarray(
18571857 return [
18581858 rewrite (NDArray .fn (shape , dtype , idx_fn ).shape , subsume = False ).to (shape ),
18591859 rewrite (NDArray .fn (shape , dtype , idx_fn ).dtype , subsume = False ).to (dtype ),
1860- rewrite (NDArray .fn (shape , dtype , idx_fn ).index (idx ), subsume = False ).to (idx_fn (idx )),
1860+ rewrite (NDArray .fn (shape , dtype , idx_fn ).index (idx ), subsume = True ).to (idx_fn (idx )),
18611861 rewrite (NDArray (rv ).shape , subsume = False ).to (rv .shape ),
18621862 rewrite (NDArray (rv ).index (TupleInt (vi )), subsume = False ).to (rv [vi ]),
18631863 # TODO: Special case scalar ops for now
@@ -1875,8 +1875,8 @@ def _ndarray(
18751875 # Transpose of transpose is the original array
18761876 rewrite (x .T .T , subsume = False ).to (x ),
18771877 # if_
1878- rewrite (NDArray .if_ (TRUE , xt , x1t ), subsume = False ).to (xt ()),
1879- rewrite (NDArray .if_ (FALSE , xt , x1t ), subsume = False ).to (x1t ()),
1878+ rewrite (NDArray .if_ (TRUE , xt , x1t ), subsume = True ).to (xt ()),
1879+ rewrite (NDArray .if_ (FALSE , xt , x1t ), subsume = True ).to (x1t ()),
18801880 # to RecursiveValue,
18811881 # only trigger if size smaller than 10 to avoid blowing up
18821882 rule (
0 commit comments