Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions megablocks/backend/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def _padded_copy(
# need to reduce the result. Using atomics is slow, so we
# do the reduce step in a second kernel.
offset = index_a // TOP_K if A_TO_B else index_a
a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

# Load the scale, if requested.
Expand Down Expand Up @@ -258,8 +258,8 @@ def _padded_copy_wgrad(

# Offset the input and output pointers.
wgrad += index_out
grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
Expand Down Expand Up @@ -365,8 +365,8 @@ def _binned_copy(
# need to reduce the result. Using atomics is slow, so we
# do the reduce step in a second kernel.
offset = index_a // TOP_K if A_TO_B else index_a
a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

# Load the scale, if requested.
Expand Down Expand Up @@ -500,8 +500,8 @@ def _binned_copy_wgrad(

# Offset the input and output pointers.
wgrad += index_out
grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
Expand Down