@@ -836,7 +836,7 @@ def _tuple_int(
836836 yield rule (eq (ti ).to (TupleInt (vs )), eq (ti ).to (TupleInt (vs2 )), vs != vs2 ).then (vs | vs2 )
837837
838838 yield rewrite (TupleInt .fn (i2 , idx_fn ).length (), subsume = False ).to (i2 )
839- yield rewrite (TupleInt .fn (i2 , idx_fn )[i ], subsume = False ).to (idx_fn (check_index (i2 , i )))
839+ yield rewrite (TupleInt .fn (i2 , idx_fn )[i ], subsume = True ).to (idx_fn (check_index (i2 , i )))
840840
841841 yield rewrite (TupleInt (vs ).length ()).to (Int (vs .length ()))
842842 yield rewrite (TupleInt (vs )[Int (k )]).to (vs [k ])
@@ -955,7 +955,7 @@ def _tuple_tuple_int(
955955 yield rule (eq (ti ).to (TupleTupleInt (vs )), eq (ti ).to (TupleTupleInt (vs2 )), vs != vs2 ).then (vs | vs2 )
956956
957957 yield rewrite (TupleTupleInt .fn (i2 , idx_fn ).length (), subsume = False ).to (i2 )
958- yield rewrite (TupleTupleInt .fn (i2 , idx_fn )[i ], subsume = False ).to (idx_fn (check_index (i2 , i )))
958+ yield rewrite (TupleTupleInt .fn (i2 , idx_fn )[i ], subsume = True ).to (idx_fn (check_index (i2 , i )))
959959
960960 yield rewrite (TupleTupleInt (vs ).length (), subsume = False ).to (Int (vs .length ()))
961961 yield rewrite (TupleTupleInt (vs )[Int (k )], subsume = False ).to (vs [k ])
@@ -1330,7 +1330,7 @@ def _tuple_value(
13301330 yield rule (eq (ti ).to (TupleValue (vs )), eq (ti ).to (TupleValue (vs2 )), vs != vs2 ).then (vs | vs2 )
13311331
13321332 yield rewrite (TupleValue .fn (i2 , idx_fn ).length (), subsume = False ).to (i2 )
1333- yield rewrite (TupleValue .fn (i2 , idx_fn )[i ], subsume = False ).to (idx_fn (check_index (i2 , i )))
1333+ yield rewrite (TupleValue .fn (i2 , idx_fn )[i ], subsume = True ).to (idx_fn (check_index (i2 , i )))
13341334
13351335 yield rewrite (TupleValue (vs ).length (), subsume = False ).to (Int (vs .length ()))
13361336 yield rewrite (TupleValue (vs )[Int (k )], subsume = False ).to (vs [k ], k >= 0 , k < vs .length ())
@@ -1961,7 +1961,7 @@ def _tuple_ndarray(
19611961):
19621962 yield rule (eq (ti ).to (TupleNDArray (vs )), eq (ti ).to (TupleNDArray (vs2 )), vs != vs2 ).then (vs | vs2 )
19631963 yield rewrite (TupleNDArray .fn (i2 , idx_fn ).length (), subsume = False ).to (i2 )
1964- yield rewrite (TupleNDArray .fn (i2 , idx_fn )[i ], subsume = False ).to (idx_fn (check_index (i2 , i )))
1964+ yield rewrite (TupleNDArray .fn (i2 , idx_fn )[i ], subsume = True ).to (idx_fn (check_index (i2 , i )))
19651965
19661966 yield rewrite (TupleNDArray (vs ).length (), subsume = False ).to (Int (vs .length ()))
19671967 yield rewrite (TupleNDArray (vs )[Int (k )], subsume = False ).to (vs [k ], k >= 0 , k < vs .length ())
@@ -2148,26 +2148,33 @@ def _astype(x: NDArray, dtype: DType, i: i64):
21482148 ]
21492149
21502150
2151- @function
2151+ @function ( unextractable = True , ruleset = array_api_ruleset )
21522152def unique_counts (x : NDArray ) -> TupleNDArray :
21532153 """
21542154 Returns the unique elements of an input array x and the corresponding counts for each unique element in x.
21552155
21562156
21572157 https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_counts.html
21582158 """
2159+ return TupleNDArray ((unique_counts_elements (x ), unique_counts_counts (x )))
2160+
2161+
2162+ @function
2163+ def unique_counts_elements (x : NDArray ) -> NDArray : ...
2164+
2165+
2166+ @function
2167+ def unique_counts_counts (x : NDArray ) -> NDArray : ...
21592168
21602169
21612170@array_api_ruleset .register
21622171def _unique_counts (x : NDArray , c : NDArray , tv : TupleValue , v : Value , dtype : DType ):
21632172 return [
2164- # rewrite(unique_counts(x).length()).to(Int(2)),
2165- rewrite (unique_counts (x )).to (TupleNDArray .fn (2 , unique_counts (x ).__getitem__ )),
21662173 # Sum of all unique counts is the size of the array
2167- rewrite (sum (unique_counts (x )[ Int ( 1 )] )).to (NDArray (Value .from_int (x .size ))),
2174+ rewrite (sum (unique_counts_counts (x ))).to (NDArray (Value .from_int (x .size ))),
21682175 # Same but with astype in the middle
21692176 # TODO: Replace
2170- rewrite (sum (astype (unique_counts (x )[ Int ( 1 )] , dtype ))).to (astype (NDArray (Value .from_int (x .size )), dtype )),
2177+ rewrite (sum (astype (unique_counts_counts (x ), dtype ))).to (astype (NDArray (Value .from_int (x .size )), dtype )),
21712178 ]
21722179
21732180
@@ -2194,22 +2201,25 @@ def _abs(f: Float):
21942201 ]
21952202
21962203
2197- @function
2204+ @function ( ruleset = array_api_ruleset , unextractable = True )
21982205def unique_inverse (x : NDArray ) -> TupleNDArray :
21992206 """
22002207 Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.
22012208
22022209 https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.unique_inverse.html
22032210 """
2211+ return TupleNDArray ((unique_values (x ), unique_inverse_inverse_indices (x )))
2212+
2213+
2214+ @function
2215+ def unique_inverse_inverse_indices (x : NDArray ) -> NDArray : ...
22042216
22052217
22062218@array_api_ruleset .register
22072219def _unique_inverse (x : NDArray , i : Int ):
22082220 return [
2209- # rewrite(unique_inverse(x).length()).to(Int(2)),
2210- rewrite (unique_inverse (x )).to (TupleNDArray .fn (2 , unique_inverse (x ).__getitem__ )),
22112221 # Shape of unique_inverse first element is same as shape of unique_values
2212- rewrite (unique_inverse (x )[Int (0 )]).to (unique_values (x )),
2222+ rewrite (unique_values (x )[Int (0 )]).to (unique_values (x )),
22132223 ]
22142224
22152225
@@ -2319,19 +2329,19 @@ def cross(a: NDArrayLike, b: NDArrayLike) -> NDArray:
23192329linalg = sys .modules [__name__ ]
23202330
23212331
2322- @function
2323- def svd (x : NDArray , full_matrices : Boolean = TRUE ) -> TupleNDArray :
2332+ def svd (x : NDArray , full_matrices : Boolean = TRUE ) -> tuple [NDArray , NDArray , NDArray ]:
23242333 """
23252334 https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
23262335 """
2336+ res = svd_ (x , full_matrices )
2337+ return (res [0 ], res [1 ], res [2 ])
23272338
23282339
2329- @array_api_ruleset .register
2330- def _linalg (x : NDArray , full_matrices : Boolean ):
2331- return [
2332- # rewrite(svd(x, full_matrices).length()).to(Int(3)),
2333- rewrite (svd (x , full_matrices )).to (TupleNDArray .fn (3 , svd (x , full_matrices ).__getitem__ )),
2334- ]
2340+ @function
2341+ def svd_ (x : NDArray , full_matrices : Boolean = TRUE ) -> TupleNDArray :
2342+ """
2343+ https://data-apis.org/array-api/2022.12/extensions/generated/array_api.linalg.svd.html
2344+ """
23352345
23362346
23372347@function (ruleset = array_api_ruleset , unextractable = True )
@@ -2427,7 +2437,7 @@ def _interval_analaysis(
24272437 # rule(eq(y).to(any(x)), ndarray_all_false(x)).then(union(y).with_(NDArray(Value.bool(FALSE)))),
24282438 # Indexing into unique counts counts are all positive
24292439 rule (
2430- eq (v ).to (unique_counts (x )[ Int ( 1 )] .index (idx )),
2440+ eq (v ).to (unique_counts_counts (x ).index (idx )),
24312441 ).then (greater_zero (v )),
24322442 # Min value preserved over astype
24332443 rule (
@@ -2710,15 +2720,20 @@ def try_evaling(expr: ExprWithValue[T_co]) -> T_co:
27102720 egraph = _get_current_egraph ()
27112721 egraph .register (expr ) # type: ignore[arg-type]
27122722 egraph .run (array_api_schedule )
2713- # run on another e-graph to get around bug
2714- # https://github.com/egraphs-good/egglog/issues/801
2715- # return egraph.extract(expr).value # type: ignore[call-overload]
27162723 # egraph.display(n_inline_leaves=2, split_primitive_outputs=True, split_functions=[Int])
27172724 extracted_expr = egraph .extract (expr ) # type: ignore[call-overload]
2725+ with contextlib .suppress (ExprValueError ):
2726+ extracted_expr .value
2727+ # run on another e-graph to get around bug
2728+ # https://github.com/egraphs-good/egglog/issues/801
2729+
27182730 new_egraph = EGraph ()
27192731 new_egraph .register (extracted_expr )
27202732 new_egraph .run (array_api_schedule )
27212733 return new_egraph .extract (extracted_expr ).value
2734+ # except EggSmolError as e:
2735+
2736+ # raise e
27222737
27232738 # try:
27242739 # return egraph.extract(prim_expr).value # type: ignore[attr-defined]
0 commit comments