@@ -1857,101 +1857,141 @@ def contract(
             assert n == rhs2.shape[1]
             assert k == rhs2.shape[0]

-            def rounding_divide(
-                lhs: tuple[int, ...], rhs: tuple[int, ...]
-            ) -> tuple[int, ...]:
-                return tuple(
-                    (lh + rh - 1) // rh for (lh, rh) in zip(lhs, rhs)
-                )
+            # decide whether to run the full 3D matmul or the k-batched version;
+            # choose the batched version only if memory exceeds the threshold
+            def use_legacy_matmul(
+                num_procs: int, m: int, n: int, k: int, itemsize: int
+            ) -> bool:
+                # runtime.num_procs == 1 --> legacy matmul
+                if not settings.test() and num_procs == 1:
+                    return True
+
+                # approximate whether batching would actually be triggered here
+                return (
+                    m + n
+                ) * k * itemsize < settings.matmul_cache_size() * num_procs
+
+            use_3d_matmul = use_legacy_matmul(
+                runtime.num_procs, m, n, k, rhs1_thunk.dtype.itemsize
+            )
+
+            if use_3d_matmul:
+                lhs = lhs.promote(1, k)
+                rhs1 = rhs1.promote(2, n)
+                rhs2 = rhs2.promote(0, m)

-            # TODO: better heuristics
-            def choose_2d_color_shape(
-                shape: tuple[int, int],
-            ) -> tuple[int, int]:
-                # 1M elements, we should probably even go larger
-                MIN_MATRIX_SIZE = 1 << 20
-                # If the matrix is too small don't partition it at all
-                if (not settings.test()) and shape[0] * shape[
-                    1
-                ] <= MIN_MATRIX_SIZE:
-                    return (1, 1)
-
-                # start with 1D and re-balance by powers of 2
-                # (don't worry about other primes)
-                color_shape = (runtime.num_procs, 1)
-                while (
-                    shape[0] / color_shape[0]
-                    < 2 * shape[1] / color_shape[1]
-                    and color_shape[0] % 2 == 0
-                ):
-                    color_shape = (color_shape[0] // 2, color_shape[1] * 2)
-
-                return color_shape
-
-            # TODO: better heuristics?
-            def choose_batchsize(
-                tilesize: tuple[int, ...], k: int, itemsize: int
-            ) -> int:
-                # don't batch in case we only have 1 proc
-                if runtime.num_procs == 1:
-                    return k
-
-                # default corresponds to 128MB (to store A and B tile)
-                from ..settings import settings
-
-                assert len(tilesize) >= 2
-                max_elements_per_tile = (
-                    settings.matmul_cache_size() // itemsize
+                task = legate_runtime.create_auto_task(
+                    self.library, CuPyNumericOpCode.MATMUL
                 )
-                total_elements_rhs = (tilesize[0] + tilesize[1]) * k
-                num_batches = rounding_divide(
-                    (total_elements_rhs,), (max_elements_per_tile,)
-                )[0]
-                batch_size = rounding_divide((k,), (num_batches,))[0]
-
-                return batch_size
-
-            # choose color-shape/k_batch_size
-            initial_color_shape = choose_2d_color_shape((m, n))
-            tile_shape = rounding_divide((m, n), initial_color_shape)
-            color_shape = rounding_divide((m, n), tile_shape)
-            k_batch_size = choose_batchsize(
-                tile_shape, k, rhs1_thunk.dtype.itemsize
-            )
-            k_color = rounding_divide((k,), (k_batch_size,))
+                p_lhs = task.add_reduction(lhs, ReductionOpKind.ADD)
+                p_rhs1 = task.add_input(rhs1)
+                p_rhs2 = task.add_input(rhs2)

-            # initial partition of lhs defined py tile-shape
-            tiled_lhs = lhs.partition_by_tiling(tile_shape)
-            tiled_rhs1 = rhs1.partition_by_tiling(
-                (tile_shape[0], k_batch_size)
-            )
-            tiled_rhs2 = rhs2.partition_by_tiling(
-                (k_batch_size, tile_shape[1])
-            )
+                # specify unbatched matrix multiplication:
+                unbatched = 1
+                task.add_scalar_arg(unbatched, ty.uint32)
+
+                task.add_constraint(align(p_lhs, p_rhs1))
+                task.add_constraint(align(p_lhs, p_rhs2))
+                task.execute()
+
+            else:
+                # batched matmul
+                #
+
+                def rounding_divide(
+                    lhs: tuple[int, ...], rhs: tuple[int, ...]
+                ) -> tuple[int, ...]:
+                    return tuple(
+                        (lh + rh - 1) // rh for (lh, rh) in zip(lhs, rhs)
+                    )
+
+                # manually create a 2D color shape with num_procs colors
+                def choose_2d_color_shape(
+                    shape: tuple[int, int],
+                ) -> tuple[int, int]:
+                    # start with 1D and re-balance by powers of 2
+                    # (don't worry about other primes)
+                    color_shape = (runtime.num_procs, 1)
+                    while (
+                        shape[0] / color_shape[0]
+                        < 2 * shape[1] / color_shape[1]
+                        and color_shape[0] % 2 == 0
+                    ):
+                        color_shape = (
+                            color_shape[0] // 2,
+                            color_shape[1] * 2,
+                        )

-            def run_matmul_for_batch(
-                tiled_lhs: LogicalStorePartition,
-                tiled_rhs1: LogicalStorePartition,
-                tiled_rhs2: LogicalStorePartition,
-                i: int,
-            ) -> None:
-                manual_task = legate_runtime.create_manual_task(
-                    self.library, CuPyNumericOpCode.MATMUL, color_shape
+                    return color_shape
+
+                # For a given tile size, choose a batch size that splits the
+                # k-dimension into parts small enough to keep the partitions
+                # of A and B below settings.matmul_cache_size()
+                def choose_batchsize(
+                    tilesize: tuple[int, ...], k: int, itemsize: int
+                ) -> int:
+                    # don't batch in case we only have 1 proc
+                    if runtime.num_procs == 1:
+                        return k
+
+                    assert len(tilesize) >= 2
+                    # default corresponds to 128MB (to store A and B tile)
+                    max_elements_per_tile = (
+                        settings.matmul_cache_size() // itemsize
+                    )
+                    total_elements_rhs = (tilesize[0] + tilesize[1]) * k
+                    num_batches = rounding_divide(
+                        (total_elements_rhs,), (max_elements_per_tile,)
+                    )[0]
+                    # even out batches
+                    batch_size = rounding_divide((k,), (num_batches,))[0]
+
+                    return batch_size
+
+                # choose color-shape/k_batch_size
+                initial_color_shape = choose_2d_color_shape((m, n))
+                tile_shape = rounding_divide((m, n), initial_color_shape)
+                color_shape = rounding_divide((m, n), tile_shape)
+                k_batch_size = choose_batchsize(
+                    tile_shape, k, rhs1_thunk.dtype.itemsize
                 )
+                k_color = rounding_divide((k,), (k_batch_size,))

-                manual_task.add_output(tiled_lhs)
-                manual_task.add_input(tiled_lhs)
-                manual_task.add_input(
-                    tiled_rhs1, (dimension(0), constant(i))
+                # initial partition of lhs defined by tile-shape
+                tiled_lhs = lhs.partition_by_tiling(tile_shape)
+                tiled_rhs1 = rhs1.partition_by_tiling(
+                    (tile_shape[0], k_batch_size)
                 )
-                manual_task.add_input(
-                    tiled_rhs2, (constant(i), dimension(1))
+                tiled_rhs2 = rhs2.partition_by_tiling(
+                    (k_batch_size, tile_shape[1])
                 )

-                manual_task.execute()
-
-            for i in range(0, k_color[0]):
-                run_matmul_for_batch(tiled_lhs, tiled_rhs1, tiled_rhs2, i)
+                def run_matmul_for_batch(
+                    tiled_lhs: LogicalStorePartition,
+                    tiled_rhs1: LogicalStorePartition,
+                    tiled_rhs2: LogicalStorePartition,
+                    i: int,
+                ) -> None:
+                    manual_task = legate_runtime.create_manual_task(
+                        self.library, CuPyNumericOpCode.MATMUL, color_shape
+                    )
+
+                    manual_task.add_output(tiled_lhs)
+                    manual_task.add_input(tiled_lhs)
+                    manual_task.add_input(
+                        tiled_rhs1, (dimension(0), constant(i))
+                    )
+                    manual_task.add_input(
+                        tiled_rhs2, (constant(i), dimension(1))
+                    )
+
+                    manual_task.execute()
+
+                for i in range(0, k_color[0]):
+                    run_matmul_for_batch(
+                        tiled_lhs, tiled_rhs1, tiled_rhs2, i
+                    )

         else:
             assert False
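To make the new heuristics concrete, here is a small standalone sketch that traces them for one hypothetical configuration. The `num_procs`, cache budget, and the m/n/k values below are made-up for illustration; the helper logic mirrors the diff above, with the `runtime`/`settings` dependencies replaced by plain parameters.

```python
# Standalone sketch of the partitioning heuristics from the diff above.
def rounding_divide(lhs, rhs):
    # elementwise ceiling division
    return tuple((lh + rh - 1) // rh for (lh, rh) in zip(lhs, rhs))

def choose_2d_color_shape(shape, num_procs):
    # start 1D along m, then move factors of 2 to the n axis until the
    # per-color tiles are roughly balanced (within a factor of 2)
    color_shape = (num_procs, 1)
    while (
        shape[0] / color_shape[0] < 2 * shape[1] / color_shape[1]
        and color_shape[0] % 2 == 0
    ):
        color_shape = (color_shape[0] // 2, color_shape[1] * 2)
    return color_shape

def choose_batchsize(tilesize, k, itemsize, cache_bytes, num_procs):
    if num_procs == 1:
        return k  # single processor: no batching
    max_elements_per_tile = cache_bytes // itemsize
    # each k-batch holds a (tile_m x batch) slice of A and a
    # (batch x tile_n) slice of B
    total_elements_rhs = (tilesize[0] + tilesize[1]) * k
    num_batches = rounding_divide(
        (total_elements_rhs,), (max_elements_per_tile,)
    )[0]
    return rounding_divide((k,), (num_batches,))[0]

m, n, k = 1 << 14, 1 << 12, 1 << 16    # 16384 x 4096 x 65536 (illustrative)
num_procs, itemsize = 16, 8            # float64 on 16 processors (illustrative)
cache_bytes = 128 * 2**20              # 128 MiB tile budget

# (m + n) * k * itemsize is 10 GiB here, well above the 2 GiB threshold,
# so contract() would take the batched branch instead of the 3D matmul:
print((m + n) * k * itemsize < cache_bytes * num_procs)   # False

color_shape = choose_2d_color_shape((m, n), num_procs)    # (4, 4)
tile_shape = rounding_divide((m, n), color_shape)         # (4096, 1024)
k_batch = choose_batchsize(tile_shape, k, itemsize, cache_bytes, num_procs)
print(color_shape, tile_shape, k_batch)                   # k_batch = 3277
print(rounding_divide((k,), (k_batch,))[0])               # 20 batches along k
```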
@@ -4216,48 +4256,3 @@ def stencil_hint(
         legate_runtime.prefetch_bloated_instances(
             self.base, low_offsets, high_offsets, False
         )
-
-    @auto_convert("rhs1_thunk", "rhs2_thunk")
-    def ts_matmul(self, rhs1_thunk: Any, rhs2_thunk: Any) -> Any:
-        lhs_thunk: NumPyThunk = self
-
-        # Clear output array
-        lhs_thunk.fill(np.array(0, dtype=lhs_thunk.dtype))
-        lhs = lhs_thunk.base  # type: ignore
-
-        rhs1 = rhs1_thunk.base
-        rhs2 = rhs2_thunk.base
-
-        m = lhs.shape[0]
-        n = lhs.shape[1]
-        k = rhs1.shape[1]
-        unbatched = 1
-
-        assert m == rhs1.shape[0]
-        assert n == rhs2.shape[1]
-        assert k == rhs2.shape[0]
-        lhs = lhs.promote(1, k)
-        rhs1 = rhs1.promote(2, n)
-        rhs2 = rhs2.promote(0, m)
-
-        task = legate_runtime.create_auto_task(
-            self.library, CuPyNumericOpCode.MATMUL
-        )
-        p_lhs = task.add_reduction(lhs, ReductionOpKind.ADD)
-        p_rhs1 = task.add_input(rhs1)
-        p_rhs2 = task.add_input(rhs2)
-        #
-        # specify unbatched matrix multiplication:
-        #
-        task.add_scalar_arg(unbatched, ty.uint32)
-
-        task.add_constraint(align(p_lhs, p_rhs1))
-        task.add_constraint(align(p_lhs, p_rhs2))
-        #
-        # additional constraints:
-        #
-        # task.add_constraint(broadcast(p_rhs1, (0,)))
-        # task.add_constraint(broadcast(p_rhs2, (1,)))
-        task.add_constraint(broadcast(p_lhs))
-        #
-        task.execute()
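The removed `ts_matmul` only ever issued the unbatched auto task, and that path now lives inline in `contract` (first hunk). Numerically, the k-batched branch `contract` gains is just an accumulation of partial products over k-slices. A minimal NumPy sketch of that accumulation (shapes and batch size are made up, and the Legate partitioning/launch machinery is ignored):

```python
import numpy as np

def k_batched_matmul(a, b, k_batch_size):
    # one pass per k-batch, mirroring one manual task launch per
    # batch index i in run_matmul_for_batch above
    m, k = a.shape
    n = b.shape[1]
    c = np.zeros((m, n), dtype=a.dtype)
    for j in range(0, k, k_batch_size):
        c += a[:, j : j + k_batch_size] @ b[j : j + k_batch_size, :]
    return c

rng = np.random.default_rng(0)
a = rng.random((64, 1000))
b = rng.random((1000, 32))
assert np.allclose(k_batched_matmul(a, b, 256), a @ b)
```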