Skip to content

Commit ffb1610

Browse files
committed
gpu umami bugfix
1 parent a01af30 commit ffb1610

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

madspace/src/gpu/runtime.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void op_matrix_element(
3939
TensorVec contiguous_inputs(input_count);
4040
std::vector<UmamiInputKey> input_keys(input_count + 1);
4141
std::vector<UmamiOutputKey> output_keys(output_count);
42-
std::vector<void*> input_ptrs(input_count + 1), output_ptrs(output_count);
42+
std::vector<void*> input_ptrs(input_count), output_ptrs(output_count + 1);
4343
for (std::size_t i = 0; i < input_count; ++i) {
4444
input_keys[i] = static_cast<UmamiInputKey>(
4545
locals[instruction.input_indices[3 + 2 * i]].index_value()
@@ -50,8 +50,6 @@ void op_matrix_element(
5050
);
5151
input_ptrs[i] = contiguous_inputs[i].data();
5252
}
53-
input_keys[input_count] = UMAMI_IN_GPU_STREAM;
54-
input_ptrs[input_count] = device.stream();
5553
std::size_t output_offset = 3 + 2 * input_count;
5654
for (std::size_t i = 0; i < output_count; ++i) {
5755
output_keys[i] = static_cast<UmamiOutputKey>(
@@ -90,6 +88,8 @@ void op_matrix_element(
9088
}
9189
}
9290
}
91+
output_keys[output_count] = UMAMI_OUT_GPU_STREAM;
92+
output_ptrs[output_count] = device.stream();
9393
if (me_index == 0xBADCAFE || batch_size == 0) {
9494
return;
9595
}

0 commit comments

Comments
 (0)