ck_tile grouped gemm: more padding #574
Conversation
Quick follow-up question: are there certain padding cases/shapes where we should prefer fallback due to the performance penalty of the padded path?
I looked at this briefly but could not find a config where this would be profitable, at least for bf16.
)

for o, o_ref in zip(out, out_ref):
    if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4):
The test itself is IS_HIP_EXTENSION-only, so the IS_HIP_EXTENSION check in this condition is redundant.
n_val = unaligned_n if "N" in pad_dim else n_aligned

total_m = sum(m_vals)
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
nit: better to use monkeypatch to make sure the env vars are cleared if the test fails
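A minimal sketch of what this suggestion could look like, assuming the test runs under pytest and takes its built-in monkeypatch fixture as an argument. The test name is hypothetical and the GEMM body is omitted; only the environment variable name comes from the diff.

```python
import os

# Environment variable name taken from the diff under review.
ENV_VAR = "NVTE_USE_CUTLASS_GROUPED_GEMM"


def test_grouped_gemm_padding_env(monkeypatch):
    # Hypothetical test name. `monkeypatch` is pytest's built-in fixture:
    # anything set through it is undone at teardown, even if the test fails.
    monkeypatch.setenv(ENV_VAR, "1")
    assert os.environ[ENV_VAR] == "1"
    # ... run the grouped GEMM and compare against the reference here ...
```

Compared with assigning to os.environ directly, this needs no try/finally and does not leak the setting into later tests.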
# M: not multiples of tile (256), varies per group.
# N: multiple of 16 but not multiple of tile (128).
unaligned_k = 2016
unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
I think z should be derived as len(unaligned_m), or the test should assert that the two are equal.
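One way to follow this suggestion, as a sketch: the list and the tile size of 256 come from the diff, while TILE_M is a hypothetical name and z is assumed to be the number of groups.

```python
# Per-group M values from the diff under review.
unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]

# Derive the group count from the list so the two cannot drift apart.
z = len(unaligned_m)

# Tile size for M as stated in the diff comment; the constant name is hypothetical.
TILE_M = 256

# Guard the premise of the test: every M must be unaligned to the tile.
assert all(m % TILE_M != 0 for m in unaligned_m)

total_m = sum(unaligned_m)
```

If z has to remain an explicit parameter, a single `assert z == len(unaligned_m)` at the top of the test gives the same protection.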
But why is a separate translation unit needed for every ck_tile_grouped_gemm_fp16_dispatch_*, and why are the functions themselves needed instead of calling ck_tile_grouped_gemm_fp16_dispatch_layout<>() directly?
Description
Enabling padding unconditionally causes a significant (~15%) reduction in speed, so enable it only when necessary.
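The idea of padding only when necessary can be illustrated with a small shape check. This is a sketch under assumptions, not the PR's dispatch logic: the function name is hypothetical, the M and N tile sizes (256 and 128) come from the test comments in the diff, and the K tile size of 64 is an assumed placeholder.

```python
from typing import Sequence, Set


def dims_needing_padding(
    m_vals: Sequence[int],
    n: int,
    k: int,
    tile_m: int = 256,  # from the test comments in the diff
    tile_n: int = 128,  # from the test comments in the diff
    tile_k: int = 64,   # assumed value, not stated in the source
) -> Set[str]:
    """Return the GEMM dimensions that are not multiples of their tile size."""
    dims: Set[str] = set()
    # M varies per group, so one unaligned group is enough to require padding.
    if any(m % tile_m != 0 for m in m_vals):
        dims.add("M")
    if n % tile_n != 0:
        dims.add("N")
    if k % tile_k != 0:
        dims.add("K")
    return dims


# An empty result means the faster unpadded kernel can be used.
use_padded_kernel = bool(dims_needing_padding([100, 300], n=256, k=2016))
```

Selecting the padded kernel only when the returned set is non-empty keeps aligned shapes on the faster path.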