1919from egglog .exp .array_api_loopnest import *
2020from egglog .exp .array_api_numba import array_api_numba_schedule
2121from egglog .exp .array_api_program_gen import *
22- from egglog .exp .program_gen import Program
22+ from egglog .exp .program_gen import EvalProgram , Program
2323
2424some_shape = constant ("some_shape" , TupleInt )
2525some_dtype = constant ("some_dtype" , DType )
@@ -327,33 +327,19 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
327327 return globals [var ]
328328
329329
330- def load_source (fn_program : EvalProgram , egraph : EGraph ):
331- egraph .register (fn_program )
332- egraph .run (array_api_program_gen_schedule )
333- # dp the needed pieces in here for benchmarking
334- try :
335- return egraph .extract (fn_program .as_py_object ).eval ()
336- except Exception as err :
337- err .add_note (f"Failed to compile the program into a string: \n \n { egraph .extract (fn_program )} " )
338- egraph .display (split_primitive_outputs = True , n_inline_leaves = 3 , split_functions = [Program ])
339- raise
340-
341-
342- def lda (X , y ):
330+ def lda (X : NDArray , y : NDArray ):
343331 assume_dtype (X , X_np .dtype )
344332 assume_shape (X , X_np .shape )
345333 assume_isfinite (X )
346334
347335 assume_dtype (y , y_np .dtype )
348336 assume_shape (y , y_np .shape )
349- assume_value_one_of (y , tuple (map (int , np .unique (y_np )))) # type: ignore[arg-type]
337+ assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
350338 return run_lda (X , y )
351339
352340
353- def simplify_lda (egraph : EGraph , expr : NDArray ) -> NDArray :
354- egraph .register (expr )
355- egraph .run (array_api_numba_schedule )
356- return egraph .extract (expr )
341+ def lda_filled ():
342+ return lda (NDArray .var ("X" ), NDArray .var ("y" ))
357343
358344
359345@pytest .mark .parametrize (
@@ -398,21 +384,20 @@ class TestLDA:
398384 """
399385
400386 def test_trace (self , snapshot_py , benchmark ):
401- X = NDArray .var ("X" )
402- y = NDArray .var ("y" )
403- with EGraph ().set_current ():
404- X_r2 = benchmark (lda , X , y )
387+ @benchmark
388+ def X_r2 ():
389+ with EGraph ().set_current ():
390+ return lda_filled ()
391+
405392 res = str (X_r2 )
406- print (res )
407393 assert res == snapshot_py
408394
409395 def test_optimize (self , snapshot_py , benchmark ):
410396 egraph = EGraph ()
411- X = NDArray .var ("X" )
412- y = NDArray .var ("y" )
413397 with egraph .set_current ():
414- expr = lda (X , y )
415- simplified = benchmark (simplify_lda , egraph , expr )
398+ expr = lda_filled ()
399+ simplified = benchmark (egraph .simplify , expr , array_api_numba_schedule )
400+
416401 assert str (simplified ) == snapshot_py
417402
418403 # @pytest.mark.xfail(reason="Original source is not working")
@@ -423,18 +408,17 @@ def test_optimize(self, snapshot_py, benchmark):
423408
424409 def test_source_optimized (self , snapshot_py , benchmark ):
425410 egraph = EGraph ()
426- X = NDArray .var ("X" )
427- y = NDArray .var ("y" )
428411 with egraph .set_current ():
429- expr = lda (X , y )
430- optimized_expr = simplify_lda (egraph , expr )
431- egraph = EGraph ()
432- fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
433- py_object = benchmark (load_source , fn_program , egraph )
412+ expr = lda_filled ()
413+ optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
414+
415+ @benchmark
416+ def py_object ():
417+ fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
418+ return try_evaling (array_api_program_gen_schedule , fn_program , fn_program .as_py_object )
419+
434420 assert np .allclose (py_object (X_np , y_np ), run_lda (X_np , y_np ))
435- with egraph .set_current ():
436- fn_object = cast (FunctionType , fn_program .as_py_object .eval ())
437- assert inspect .getsource (fn_object ) == snapshot_py
421+ assert inspect .getsource (py_object ) == snapshot_py
438422
439423 @pytest .mark .parametrize (
440424 "fn_thunk" ,
0 commit comments