|
5 | 5 | import jax |
6 | 6 | import jax.numpy as jnp |
7 | 7 | from jax.experimental import pallas as pl |
| 8 | +from jax.experimental.pallas import gpu as plgpu |
8 | 9 |
|
9 | 10 | from .mhsa import mhsa_kernel, reference_mhsa_kernel |
10 | 11 | from .mhsea import mhsea_kernel, reference_mhsea_kernel |
@@ -53,8 +54,8 @@ def mhsa_forward( |
53 | 54 | out_shape=jax.ShapeDtypeStruct( |
54 | 55 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
55 | 56 | ), |
56 | | - compiler_params=dict( |
57 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
| 57 | + compiler_params=plgpu.TritonCompilerParams( |
| 58 | + num_warps=num_warps, num_stages=num_stages |
58 | 59 | ), |
59 | 60 | debug=False, |
60 | 61 | interpret=interpret, |
@@ -113,8 +114,8 @@ def mhsa_backward( |
113 | 114 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
114 | 115 | ), |
115 | 116 | ], |
116 | | - compiler_params=dict( |
117 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
| 117 | + compiler_params=plgpu.TritonCompilerParams( |
| 118 | + num_warps=num_warps, num_stages=num_stages |
118 | 119 | ), |
119 | 120 | debug=False, |
120 | 121 | interpret=interpret, |
@@ -268,8 +269,8 @@ def mhsea_forward( |
268 | 269 | shape=(batch_len, seq_len, num_heads), dtype=v.dtype |
269 | 270 | ), |
270 | 271 | ], |
271 | | - compiler_params=dict( |
272 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
| 272 | + compiler_params=plgpu.TritonCompilerParams( |
| 273 | + num_warps=num_warps, num_stages=num_stages |
273 | 274 | ), |
274 | 275 | debug=False, |
275 | 276 | interpret=interpret, |
@@ -372,8 +373,8 @@ def mhsea_backward( |
372 | 373 | shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype |
373 | 374 | ), |
374 | 375 | ], |
375 | | - compiler_params=dict( |
376 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
| 376 | + compiler_params=plgpu.TritonCompilerParams( |
| 377 | + num_warps=num_warps, num_stages=num_stages |
377 | 378 | ), |
378 | 379 | debug=False, |
379 | 380 | interpret=interpret, |
@@ -433,8 +434,8 @@ def mhsea_backward( |
433 | 434 | shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype |
434 | 435 | ), |
435 | 436 | ], |
436 | | - compiler_params=dict( |
437 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
| 437 | + compiler_params=plgpu.TritonCompilerParams( |
| 438 | + num_warps=num_warps, num_stages=num_stages |
438 | 439 | ), |
439 | 440 | debug=False, |
440 | 441 | interpret=interpret, |
|
0 commit comments