Skip to content

Commit 8111388

Browse files
fix abs
1 parent e21d7d4 commit 8111388

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

python/egglog/exp/array_api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
264264
yield rewrite(Int(i) << Int(j)).to(Int(i << j))
265265
yield rewrite(Int(i) >> Int(j)).to(Int(i >> j))
266266
yield rewrite(~Int(i)).to(Int(~i))
267-
yield rewrite(abs(Int(i))).to(Int(abs(i)))
267+
yield rewrite(Int(i).__abs__()).to(Int(i.__abs__()))
268268

269269
yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
270270
yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
@@ -404,7 +404,7 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
404404
rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
405405
rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
406406
rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
407-
rewrite(abs(Float(f))).to(Float(abs(f))),
407+
rewrite(Float(f).__abs__()).to(Float(f.__abs__())),
408408
]
409409

410410

@@ -1819,6 +1819,10 @@ def square(x: NDArray) -> NDArray: ...
18191819
def any(x: NDArray) -> NDArray: ...
18201820

18211821

1822+
@function(egg_fn="ndarray-abs")
1823+
def abs(x: NDArray) -> NDArray: ...
1824+
1825+
18221826
@function(egg_fn="ndarray-log")
18231827
def log(x: NDArray) -> NDArray: ...
18241828

@@ -1921,7 +1925,9 @@ def vector_norm(x: NDArrayLike) -> NDArray:
19211925
x = cast("NDArray", x)
19221926
# Only works on vectors
19231927
return NDArray.scalar(
1924-
x.to_values().map_value(lambda v: abs(v) ** Value.float(Float(2.0))).foldl_value(Value.__add__, Value.float(0))
1928+
x.to_values()
1929+
.map_value(lambda v: v.__abs__() ** Value.float(Float(2.0)))
1930+
.foldl_value(Value.__add__, Value.float(0))
19251931
** Value.float(Float(0.5))
19261932
)
19271933

0 commit comments

Comments
 (0)