Commit 5890e1d

Improve matmul heuristic - fall back to legacy matmul (nv-legate#1022)
* re-add legacy 3D matrix multiply
* remove tall-skinny matmul workaround
* align C++ dot matmul code with Python
* remove deprecated test; also ensure batched code is executed during 1-proc tests
* add comments, fix memory approximation
1 parent 1b0a6ac commit 5890e1d

7 files changed: 201 additions & 271 deletions
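
The core of the change is the new fallback predicate in cupynumeric/_thunk/deferred.py below. As a standalone restatement (a sketch, not the committed code: cache_bytes and testing are illustrative stand-ins for the settings.matmul_cache_size() and settings.test() lookups), the decision reads:

def use_legacy_matmul_sketch(
    num_procs: int, m: int, n: int, k: int, itemsize: int,
    cache_bytes: int = 128 * 1024 * 1024, testing: bool = False,
) -> bool:
    # a single processor always takes the legacy 3D path outside tests
    if not testing and num_procs == 1:
        return True
    # if one (m x k) and one (k x n) operand fit in the aggregate cache
    # budget, k-batching would never trigger, so stay on the legacy path
    return (m + n) * k * itemsize < cache_bytes * num_procs

For example, 1024 x 1024 float64 operands on 4 processors give (1024 + 1024) * 1024 * 8 = 16 MiB, well under the 512 MiB aggregate budget, so the legacy path runs; at 16384 x 16384 the same product is 4 GiB and the batched path is taken.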


cupynumeric/_thunk/deferred.py

Lines changed: 126 additions & 131 deletions
@@ -1857,101 +1857,141 @@ def contract(
             assert n == rhs2.shape[1]
             assert k == rhs2.shape[0]
 
-            def rounding_divide(
-                lhs: tuple[int, ...], rhs: tuple[int, ...]
-            ) -> tuple[int, ...]:
-                return tuple(
-                    (lh + rh - 1) // rh for (lh, rh) in zip(lhs, rhs)
-                )
+            # decide whether to run full 3D matmul vs k-batched
+            # choose batched version only if memory exceeds threshold
+            def use_legacy_matmul(
+                num_procs: int, m: int, n: int, k: int, itemsize: int
+            ) -> bool:
+                # runtime.num_procs == 1 --> legacy matmul
+                if not settings.test() and num_procs == 1:
+                    return True
+
+                # approximate whether batching would actually be triggered here
+                return (
+                    m + n
+                ) * k * itemsize < settings.matmul_cache_size() * num_procs
+
+            use_3d_matmul = use_legacy_matmul(
+                runtime.num_procs, m, n, k, rhs1_thunk.dtype.itemsize
+            )
+
+            if use_3d_matmul:
+                lhs = lhs.promote(1, k)
+                rhs1 = rhs1.promote(2, n)
+                rhs2 = rhs2.promote(0, m)
 
-            # TODO: better heuristics
-            def choose_2d_color_shape(
-                shape: tuple[int, int],
-            ) -> tuple[int, int]:
-                # 1M elements, we should probably even go larger
-                MIN_MATRIX_SIZE = 1 << 20
-                # If the matrix is too small don't partition it at all
-                if (not settings.test()) and shape[0] * shape[
-                    1
-                ] <= MIN_MATRIX_SIZE:
-                    return (1, 1)
-
-                # start with 1D and re-balance by powers of 2
-                # (don't worry about other primes)
-                color_shape = (runtime.num_procs, 1)
-                while (
-                    shape[0] / color_shape[0]
-                    < 2 * shape[1] / color_shape[1]
-                    and color_shape[0] % 2 == 0
-                ):
-                    color_shape = (color_shape[0] // 2, color_shape[1] * 2)
-
-                return color_shape
-
-            # TODO: better heuristics?
-            def choose_batchsize(
-                tilesize: tuple[int, ...], k: int, itemsize: int
-            ) -> int:
-                # don't batch in case we only have 1 proc
-                if runtime.num_procs == 1:
-                    return k
-
-                # default corresponds to 128MB (to store A and B tile)
-                from ..settings import settings
-
-                assert len(tilesize) >= 2
-                max_elements_per_tile = (
-                    settings.matmul_cache_size() // itemsize
+                task = legate_runtime.create_auto_task(
+                    self.library, CuPyNumericOpCode.MATMUL
                 )
-                total_elements_rhs = (tilesize[0] + tilesize[1]) * k
-                num_batches = rounding_divide(
-                    (total_elements_rhs,), (max_elements_per_tile,)
-                )[0]
-                batch_size = rounding_divide((k,), (num_batches,))[0]
-
-                return batch_size
-
-            # choose color-shape/k_batch_size
-            initial_color_shape = choose_2d_color_shape((m, n))
-            tile_shape = rounding_divide((m, n), initial_color_shape)
-            color_shape = rounding_divide((m, n), tile_shape)
-            k_batch_size = choose_batchsize(
-                tile_shape, k, rhs1_thunk.dtype.itemsize
-            )
-            k_color = rounding_divide((k,), (k_batch_size,))
+                p_lhs = task.add_reduction(lhs, ReductionOpKind.ADD)
+                p_rhs1 = task.add_input(rhs1)
+                p_rhs2 = task.add_input(rhs2)
 
-            # initial partition of lhs defined py tile-shape
-            tiled_lhs = lhs.partition_by_tiling(tile_shape)
-            tiled_rhs1 = rhs1.partition_by_tiling(
-                (tile_shape[0], k_batch_size)
-            )
-            tiled_rhs2 = rhs2.partition_by_tiling(
-                (k_batch_size, tile_shape[1])
-            )
+                # specify unbatched matrix multiplication:
+                unbatched = 1
+                task.add_scalar_arg(unbatched, ty.uint32)
+
+                task.add_constraint(align(p_lhs, p_rhs1))
+                task.add_constraint(align(p_lhs, p_rhs2))
+                task.execute()
+
+            else:
+                # batched matmul
+                #
+
+                def rounding_divide(
+                    lhs: tuple[int, ...], rhs: tuple[int, ...]
+                ) -> tuple[int, ...]:
+                    return tuple(
+                        (lh + rh - 1) // rh for (lh, rh) in zip(lhs, rhs)
+                    )
+
+                # manually create 2d color shape with num_procs colors
+                def choose_2d_color_shape(
+                    shape: tuple[int, int],
+                ) -> tuple[int, int]:
+                    # start with 1D and re-balance by powers of 2
+                    # (don't worry about other primes)
+                    color_shape = (runtime.num_procs, 1)
+                    while (
+                        shape[0] / color_shape[0]
+                        < 2 * shape[1] / color_shape[1]
+                        and color_shape[0] % 2 == 0
+                    ):
+                        color_shape = (
+                            color_shape[0] // 2,
+                            color_shape[1] * 2,
+                        )
 
-            def run_matmul_for_batch(
-                tiled_lhs: LogicalStorePartition,
-                tiled_rhs1: LogicalStorePartition,
-                tiled_rhs2: LogicalStorePartition,
-                i: int,
-            ) -> None:
-                manual_task = legate_runtime.create_manual_task(
-                    self.library, CuPyNumericOpCode.MATMUL, color_shape
+                    return color_shape
+
+                # For a given tilesize choose a batchsize to split the
+                # k-dimension into parts that will keep the partitions
+                # of A and B below the settings.matmul_cache_size()
+                def choose_batchsize(
+                    tilesize: tuple[int, ...], k: int, itemsize: int
+                ) -> int:
+                    # don't batch in case we only have 1 proc
+                    if runtime.num_procs == 1:
+                        return k
+
+                    assert len(tilesize) >= 2
+                    # default corresponds to 128MB (to store A and B tile)
+                    max_elements_per_tile = (
+                        settings.matmul_cache_size() // itemsize
+                    )
+                    total_elements_rhs = (tilesize[0] + tilesize[1]) * k
+                    num_batches = rounding_divide(
+                        (total_elements_rhs,), (max_elements_per_tile,)
+                    )[0]
+                    # even out batches
+                    batch_size = rounding_divide((k,), (num_batches,))[0]
+
+                    return batch_size
+
+                # choose color-shape/k_batch_size
+                initial_color_shape = choose_2d_color_shape((m, n))
+                tile_shape = rounding_divide((m, n), initial_color_shape)
+                color_shape = rounding_divide((m, n), tile_shape)
+                k_batch_size = choose_batchsize(
+                    tile_shape, k, rhs1_thunk.dtype.itemsize
                 )
+                k_color = rounding_divide((k,), (k_batch_size,))
 
-                manual_task.add_output(tiled_lhs)
-                manual_task.add_input(tiled_lhs)
-                manual_task.add_input(
-                    tiled_rhs1, (dimension(0), constant(i))
+                # initial partition of lhs defined py tile-shape
+                tiled_lhs = lhs.partition_by_tiling(tile_shape)
+                tiled_rhs1 = rhs1.partition_by_tiling(
+                    (tile_shape[0], k_batch_size)
                 )
-                manual_task.add_input(
-                    tiled_rhs2, (constant(i), dimension(1))
+                tiled_rhs2 = rhs2.partition_by_tiling(
+                    (k_batch_size, tile_shape[1])
                 )
 
-                manual_task.execute()
-
-            for i in range(0, k_color[0]):
-                run_matmul_for_batch(tiled_lhs, tiled_rhs1, tiled_rhs2, i)
+                def run_matmul_for_batch(
+                    tiled_lhs: LogicalStorePartition,
+                    tiled_rhs1: LogicalStorePartition,
+                    tiled_rhs2: LogicalStorePartition,
+                    i: int,
+                ) -> None:
+                    manual_task = legate_runtime.create_manual_task(
+                        self.library, CuPyNumericOpCode.MATMUL, color_shape
+                    )
+
+                    manual_task.add_output(tiled_lhs)
+                    manual_task.add_input(tiled_lhs)
+                    manual_task.add_input(
+                        tiled_rhs1, (dimension(0), constant(i))
+                    )
+                    manual_task.add_input(
+                        tiled_rhs2, (constant(i), dimension(1))
+                    )
+
+                    manual_task.execute()
+
+                for i in range(0, k_color[0]):
+                    run_matmul_for_batch(
+                        tiled_lhs, tiled_rhs1, tiled_rhs2, i
+                    )
 
         else:
             assert False

@@ -4216,48 +4256,3 @@ def stencil_hint(
         legate_runtime.prefetch_bloated_instances(
             self.base, low_offsets, high_offsets, False
         )
-
-    @auto_convert("rhs1_thunk", "rhs2_thunk")
-    def ts_matmul(self, rhs1_thunk: Any, rhs2_thunk: Any) -> Any:
-        lhs_thunk: NumPyThunk = self
-
-        # Clear output array
-        lhs_thunk.fill(np.array(0, dtype=lhs_thunk.dtype))
-        lhs = lhs_thunk.base  # type: ignore
-
-        rhs1 = rhs1_thunk.base
-        rhs2 = rhs2_thunk.base
-
-        m = lhs.shape[0]
-        n = lhs.shape[1]
-        k = rhs1.shape[1]
-        unbatched = 1
-
-        assert m == rhs1.shape[0]
-        assert n == rhs2.shape[1]
-        assert k == rhs2.shape[0]
-        lhs = lhs.promote(1, k)
-        rhs1 = rhs1.promote(2, n)
-        rhs2 = rhs2.promote(0, m)
-
-        task = legate_runtime.create_auto_task(
-            self.library, CuPyNumericOpCode.MATMUL
-        )
-        p_lhs = task.add_reduction(lhs, ReductionOpKind.ADD)
-        p_rhs1 = task.add_input(rhs1)
-        p_rhs2 = task.add_input(rhs2)
-        #
-        # specify unbatched matrix multiplication:
-        #
-        task.add_scalar_arg(unbatched, ty.uint32)
-
-        task.add_constraint(align(p_lhs, p_rhs1))
-        task.add_constraint(align(p_lhs, p_rhs2))
-        #
-        # additional constraints:
-        #
-        # task.add_constraint(broadcast(p_rhs1, (0,)))
-        # task.add_constraint(broadcast(p_rhs2, (1,)))
-        task.add_constraint(broadcast(p_lhs))
-        #
-        task.execute()
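
To see the arithmetic of the k-batched branch concretely, here is a worked trace with hypothetical sizes (square 16384 x 16384 float64 operands, a (1, 4) color shape, and the default 128 MiB budget; none of these numbers come from the commit):

def rounding_divide(lhs, rhs):
    # elementwise ceiling division, as in the diff above
    return tuple((lh + rh - 1) // rh for lh, rh in zip(lhs, rhs))

k = 16384
itemsize = 8  # float64
tile_shape = (16384, 4096)  # (m, n) split by a (1, 4) color shape
max_elements_per_tile = (128 * 1024 * 1024) // itemsize  # 16,777,216
total_elements_rhs = (tile_shape[0] + tile_shape[1]) * k  # 335,544,320
num_batches = rounding_divide(
    (total_elements_rhs,), (max_elements_per_tile,)
)[0]  # 20
k_batch_size = rounding_divide((k,), (num_batches,))[0]  # 820

The k dimension is then swept in ceil(16384 / 820) = 20 manual MATMUL launches, each reading a 16384 x 820 slab of A and an 820 x 4096 slab of B per color.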

cupynumeric/_thunk/eager.py

Lines changed: 0 additions & 3 deletions
@@ -2147,9 +2147,6 @@ def stencil_hint(
         if self.deferred is not None:
             self.deferred.stencil_hint(low_offsets, high_offsets)
 
-    def ts_matmul(self, rhs1_thunk: Any, rhs2_thunk: Any) -> Any:
-        np.matmul(rhs1_thunk.array, rhs2_thunk.array, out=self.array)
-
     def in1d(
         self,
         ar2: Any,

cupynumeric/_thunk/thunk.py

Lines changed: 0 additions & 3 deletions
@@ -1512,9 +1512,6 @@ def stencil_hint(
         self, low_offsets: tuple[int, ...], high_offsets: tuple[int, ...]
     ) -> None: ...
 
-    @abstractmethod
-    def ts_matmul(self, rhs1_thunk: Any, rhs2_thunk: Any) -> Any: ...
-
     @abstractmethod
     def in1d(
         self,

cupynumeric/linalg/linalg.py

Lines changed: 5 additions & 34 deletions
@@ -39,7 +39,7 @@
 from .._module import dot, empty_like, eye, matmul, ndarray
 from .._module.array_rearrange import flip
 from .._module.creation_matrices import diag
-from .._module.creation_shape import empty, zeros, zeros_like
+from .._module.creation_shape import zeros, zeros_like
 from .._module.ssc_sorting import argsort
 from .._ufunc.math import add, sqrt as _sqrt
 from ._exception import LinAlgError

@@ -1573,25 +1573,10 @@ def tssvd(a: ndarray) -> tuple[ndarray, ...]:
     if a.ndim != 2 or a.size <= 1:
         raise ValueError(f"Invalid input shape for tssvd: {a.shape}")
 
-    m_info = get_machine()
-
     # A.T*A:
     #
-    # unbatched way (there's a bug resulting in 0-matrix, it seems):
-    # {
-    m = a.shape[0]
-    n = a.shape[1]
-
     # TODO: Grammian API:
-    #
-    a2 = empty(shape=(n, n), dtype=a.dtype)
-    ah = a.transpose().conj()
-    a2._thunk.ts_matmul(ah._thunk, a._thunk)
-    # }
-    #
-    # batched way (slower, but passes):
-    #
-    # a2 = matmul(a.transpose().conj(), a)
+    a2 = a.transpose().conj() @ a
 
     # eigen-vals, eigen-vecs of A.T*A:
     #

@@ -1610,14 +1595,7 @@ def tssvd(a: ndarray) -> tuple[ndarray, ...]:
     # generate index permutation, pi
     # via sort-by-key decreasingly:
     #
-    d_indices = zeros(shape=(n,), dtype=np.int64)
-    with m_info[0]:  # !
-        d_indices = argsort(svals)
-    #
-    # reverse:
-    #
-    # d_indices = d_indices[::-1]  # Error: not implemented
-    d_indices = flip(d_indices)
+    d_indices = flip(argsort(svals))
 
     # V.T:
     #

@@ -1628,14 +1606,7 @@ def tssvd(a: ndarray) -> tuple[ndarray, ...]:
 
     # U = A*V*inv(S):
     #
-    # B = matmul(ev, Sinv)
-    # u = matmul(a, B)
-
-    B = empty(shape=(n, n), dtype=a.dtype)
-    B._thunk.ts_matmul(ev._thunk, Sinv._thunk)
-
-    u = empty(shape=(m, n), dtype=a.dtype)
-    u._thunk.ts_matmul(a._thunk, B._thunk)
+    u = a @ (ev @ Sinv)
 
     # re-arrange svals decreasingly:
     #

@@ -1644,7 +1615,7 @@ def tssvd(a: ndarray) -> tuple[ndarray, ...]:
 
     # permute columns of U with pi:
     #
     # u = u[:, d_indices]
-    u = matmul(u, eye(u.shape[1])[d_indices].T)
+    u = u @ eye(u.shape[1])[d_indices].T
 
     # permute rows of V.T with pi:
     #
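
With ts_matmul gone, tssvd now composes ordinary matmuls end to end. For reference, a plain-NumPy sketch of the same Grammian-based tall-skinny SVD recipe (an illustration assuming full column rank, not the library's code):

import numpy as np

def tssvd_sketch(a: np.ndarray):
    # Grammian A^H A is only n x n for a tall-skinny m x n input
    a2 = a.conj().T @ a
    evals, ev = np.linalg.eigh(a2)  # ascending eigenvalues
    svals = np.sqrt(np.clip(evals, 0.0, None))
    order = np.flip(np.argsort(svals))  # decreasing singular values
    svals, ev = svals[order], ev[:, order]
    u = a @ (ev / svals)  # U = A V S^-1 (divide each column of V)
    return u, svals, ev.conj().T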

examples/tssvd.py

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ def run_tssvd(m, n, perform_check, timing):
     if timing:
         print(f"TSSVD elapsed Time: {total:.3f} ms")
 
+    return total
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
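
Since run_tssvd now returns the elapsed time, callers can collect timings across problem sizes. A hypothetical sweep (assuming timing is enabled so total is computed):

for m in (1 << 12, 1 << 14, 1 << 16):
    elapsed = run_tssvd(m, 64, perform_check=False, timing=True)
    print(f"m={m:>6}: {elapsed:.3f} ms")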

0 commit comments