diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index b584ceed..f553038b 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -82,8 +82,8 @@ def _padded_copy( # need to reduce the result. Using atomics is slow, so we # do the reduce step in a second kernel. offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) # Load the scale, if requested. @@ -258,8 +258,8 @@ def _padded_copy_wgrad( # Offset the input and output pointers. wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) acc = tl.zeros((BLOCK_X,), dtype=tl.float32) @@ -365,8 +365,8 @@ def _binned_copy( # need to reduce the result. Using atomics is slow, so we # do the reduce step in a second kernel. offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) # Load the scale, if requested. @@ -500,8 +500,8 @@ def _binned_copy_wgrad( # Offset the input and output pointers. wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS) offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) acc = tl.zeros((BLOCK_X,), dtype=tl.float32)