1818
1919import legate .core .types as ty
2020from legate .core import broadcast , get_legate_runtime
21+ from legate .settings import settings
2122
2223from ..config import CuPyNumericOpCode
2324from ..runtime import runtime
@@ -46,8 +47,8 @@ def solve_single(library: Library, a: LogicalStore, b: LogicalStore) -> None:
4647 task .execute ()
4748
4849
49- MIN_SOLVE_TILE_SIZE = 512
50- MIN_SOLVE_MATRIX_SIZE = 2048
50+ MIN_SOLVE_TILE_SIZE = 2 if settings . test () else 512
51+ MIN_SOLVE_MATRIX_SIZE = 4 if settings . test () else 2048
5152
5253
5354def mp_solve (
@@ -59,14 +60,24 @@ def mp_solve(
5960 b : LogicalStore ,
6061 output : LogicalStore ,
6162) -> None :
62- task = get_legate_runtime ().create_auto_task (
63- library , CuPyNumericOpCode .MP_SOLVE
63+ # derive the coloring from the GPU count (runtime.num_gpus) to keep all GPUs utilized
64+ initial_color_shape_x = runtime .num_gpus
65+ tilesize_x = (n + initial_color_shape_x - 1 ) // initial_color_shape_x
66+ color_shape_x = (n + tilesize_x - 1 ) // tilesize_x
67+
68+ task = get_legate_runtime ().create_manual_task (
69+ library , CuPyNumericOpCode .MP_SOLVE , (color_shape_x , 1 )
6470 )
6571 task .throws_exception (LinAlgError )
66- task .add_input (a )
67- task .add_input (b )
68- task .add_output (output )
69- task .add_alignment (output , b )
72+
73+ tiled_a = a .partition_by_tiling ((tilesize_x , n ))
74+ tiled_b = b .partition_by_tiling ((tilesize_x , nrhs ))
75+ tiled_output = output .partition_by_tiling ((tilesize_x , nrhs ))
76+
77+ task .add_input (tiled_a )
78+ task .add_input (tiled_b )
79+ task .add_output (tiled_output )
80+
7081 task .add_scalar_arg (n , ty .int64 )
7182 task .add_scalar_arg (nrhs , ty .int64 )
7283 task .add_scalar_arg (nb , ty .int64 )
0 commit comments