Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/core/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,6 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op
throw Error("Invalid channel type", ErrorCode::ExecutorError);
};

uint32_t tbId = 0;
uint32_t tbgSize = 1;

operation.type = static_cast<mscclpp::OperationType>(getOpType(op["name"]));
if (op.contains("channel_type")) {
operation.channelType = convertToChannelType(op["channel_type"]);
Expand All @@ -509,12 +506,15 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op
operation.channelIndexes[i] = op["channel_ids"][i];
}
}
if (op.contains("tbg_info")) {
tbId = op["tbg_info"]["tb_id"];
tbgSize = op["tbg_info"]["tbg_size"];
}
if (op.contains("src_buff")) {
operation.nInputs = op["src_buff"].size();
if (op.contains("tbg_info")) {
operation.tbId = op["tbg_info"]["tb_id"];
operation.tbgSize = op["tbg_info"]["tbg_size"];
} else {
operation.tbId = 0;
operation.tbgSize = 1;
}
for (int i = 0; i < operation.nInputs; i++) {
auto& buff = op["src_buff"][i];
size_t constOffset = 0;
Expand All @@ -537,14 +537,21 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op
}
size_t inputOffset = this->getOffset(this->inputSize, this->outputSize, buff["index"], bufferType) + constOffset;
size_t inputBufferSize = this->getBufferSize(this->inputSize, this->outputSize, buff["index"], buff["size"]);
inputOffset += calcOffset(inputBufferSize, tbId, tbgSize);
inputBufferSize = calcSize(inputBufferSize, tbId, tbgSize);
inputOffset += calcOffset(inputBufferSize, 0, 1);
inputBufferSize = calcSize(inputBufferSize, 0, 1);
operation.inputOffsets[i] = inputOffset;
operation.inputBufferSizes[i] = inputBufferSize;
}
}
if (op.contains("dst_buff")) {
operation.nOutputs = op["dst_buff"].size();
if (op.contains("tbg_info")) {
operation.tbId = op["tbg_info"]["tb_id"];
operation.tbgSize = op["tbg_info"]["tbg_size"];
} else {
operation.tbId = 0;
operation.tbgSize = 1;
}
for (int i = 0; i < operation.nOutputs; i++) {
auto& buff = op["dst_buff"][i];
size_t constOffset = 0;
Expand All @@ -567,8 +574,8 @@ void ExecutionPlan::Impl::setupOperation(const nlohmann::json& op, Operation& op
}
size_t outputOffset = this->getOffset(this->inputSize, this->outputSize, buff["index"], bufferType) + constOffset;
size_t outputBufferSize = this->getBufferSize(this->inputSize, this->outputSize, buff["index"], buff["size"]);
outputOffset += calcOffset(outputBufferSize, tbId, tbgSize);
outputBufferSize = calcSize(outputBufferSize, tbId, tbgSize);
outputOffset += calcOffset(outputBufferSize, 0, 1);
outputBufferSize = calcSize(outputBufferSize, 0, 1);
operation.outputOffsets[i] = outputOffset;
operation.outputBufferSizes[i] = outputBufferSize;
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/include/execution_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ struct Operation {
uint8_t nChannels;
uint8_t nInputs;
uint8_t nOutputs;

uint8_t tbId;
uint8_t tbgSize;
};
struct {
uint32_t unitSize;
Expand Down
Loading
Loading