File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -39,7 +39,7 @@ void op_matrix_element(
3939 TensorVec contiguous_inputs (input_count);
4040 std::vector<UmamiInputKey> input_keys (input_count + 1 );
4141 std::vector<UmamiOutputKey> output_keys (output_count);
42- std::vector<void *> input_ptrs (input_count + 1 ), output_ptrs (output_count);
42+ std::vector<void *> input_ptrs (input_count), output_ptrs (output_count + 1 );
4343 for (std::size_t i = 0 ; i < input_count; ++i) {
4444 input_keys[i] = static_cast <UmamiInputKey>(
4545 locals[instruction.input_indices [3 + 2 * i]].index_value ()
@@ -50,8 +50,6 @@ void op_matrix_element(
5050 );
5151 input_ptrs[i] = contiguous_inputs[i].data ();
5252 }
53- input_keys[input_count] = UMAMI_IN_GPU_STREAM;
54- input_ptrs[input_count] = device.stream ();
5553 std::size_t output_offset = 3 + 2 * input_count;
5654 for (std::size_t i = 0 ; i < output_count; ++i) {
5755 output_keys[i] = static_cast <UmamiOutputKey>(
@@ -90,6 +88,8 @@ void op_matrix_element(
9088 }
9189 }
9290 }
91+ output_keys[output_count] = UMAMI_OUT_GPU_STREAM;
92+ output_ptrs[output_count] = device.stream ();
9393 if (me_index == 0xBADCAFE || batch_size == 0 ) {
9494 return ;
9595 }
You can’t perform that action at this time.
0 commit comments