Skip to content

Commit 805212c

Browse files
issue/143 fix bench script, worker cleanup, compiler initial input
1 parent 1e739f2 commit 805212c

3 files changed

Lines changed: 58 additions & 14 deletions

File tree

csrc/engine/compiler/paged_compiler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
#include "paged_compiler.hpp"
22

3+
namespace {
4+
// Todo: replace with Tensor::zeros when it is available
5+
inline void set_zeros(infinicore::Tensor &tensor) {
6+
std::vector<uint8_t> zeros(tensor->nbytes(), 0);
7+
infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false);
8+
}
9+
10+
} // namespace
311
namespace infinilm::engine {
412
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
513
: GraphCompiler(model, barrier) {
@@ -27,22 +35,28 @@ void PagedCompiler::compile() {
2735
compiled_map_decode_.clear();
2836
block_tables_holder_ = infinicore::Tensor::empty(
2937
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
38+
set_zeros(block_tables_holder_);
3039
for (size_t b : decode_batch_sizes_) {
3140
size_t block_per_req = nblocks / b;
3241
InfinilmModel::Input input;
3342
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
3443
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
3544
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
45+
set_zeros(input.input_ids.value());
46+
set_zeros(input.position_ids.value());
47+
set_zeros(input.total_sequence_lengths.value());
3648
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
3749
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
3850
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
51+
set_zeros(input.input_offsets.value());
3952
std::vector<int64_t> input_offsets_vec(b + 1, 0);
4053
for (size_t i = 0; i <= b; i++) {
4154
input_offsets_vec[i] = i;
4255
}
4356
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
4457
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
4558
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
59+
set_zeros(input.slot_mapping.value());
4660

4761
barrier_->wait();
4862
infinicore::context::startGraphRecording();

csrc/engine/rank_worker.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,12 @@ void RankWorker::thread_loop() {
245245
try {
246246
model_->load_parameter(local_param_name, local_param);
247247
} catch (const std::exception &e) {
248-
// convert exceptions to a safe behavior: set should_exit_ and notify caller
249-
std::lock_guard<std::mutex> lk(mutex_);
250-
should_exit_ = true;
251-
job_done_ = true;
248+
{
249+
std::lock_guard<std::mutex> lk(mutex_);
250+
should_exit_ = true;
251+
job_done_ = true;
252+
}
252253
cv_.notify_all();
253-
// rethrow so the thread can be joined and caller sees an error if desired (optional)
254254
spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
255255
break;
256256
}
@@ -320,9 +320,11 @@ void RankWorker::thread_loop() {
320320
cv_.notify_all();
321321

322322
} catch (const std::exception &e) {
323-
std::lock_guard<std::mutex> lk(mutex_);
324-
should_exit_ = true;
325-
job_done_ = true;
323+
{
324+
std::lock_guard<std::mutex> lk(mutex_);
325+
should_exit_ = true;
326+
job_done_ = true;
327+
}
326328
cv_.notify_all();
327329
spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
328330
break;
@@ -337,9 +339,11 @@ void RankWorker::thread_loop() {
337339
cv_.notify_all();
338340

339341
} catch (const std::exception &e) {
340-
std::lock_guard<std::mutex> lk(mutex_);
341-
should_exit_ = true;
342-
job_done_ = true;
342+
{
343+
std::lock_guard<std::mutex> lk(mutex_);
344+
should_exit_ = true;
345+
job_done_ = true;
346+
}
343347
cv_.notify_all();
344348
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
345349
break;
@@ -356,9 +360,11 @@ void RankWorker::thread_loop() {
356360
cv_.notify_all();
357361

358362
} catch (const std::exception &e) {
359-
std::lock_guard<std::mutex> lk(mutex_);
360-
should_exit_ = true;
361-
job_done_ = true;
363+
{
364+
std::lock_guard<std::mutex> lk(mutex_);
365+
should_exit_ = true;
366+
job_done_ = true;
367+
}
362368
cv_.notify_all();
363369
spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
364370
break;
@@ -368,6 +374,9 @@ void RankWorker::thread_loop() {
368374
// Shouldn't reach here (no-op)
369375
}
370376
} // while
377+
378+
// Some clean up should be done before exiting the thread
379+
compiler_.reset();
371380
} catch (const std::exception &e) {
372381
// Top-level exception: ensure any waiters are woken and the thread exits cleanly.
373382
{

examples/bench.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,21 @@ def get_args():
137137
action="store_true",
138138
help="Run nvidia test",
139139
)
140+
parser.add_argument(
141+
"--metax",
142+
action="store_true",
143+
help="Run metax test",
144+
)
145+
parser.add_argument(
146+
"--moore",
147+
action="store_true",
148+
help="Run moore test",
149+
)
150+
parser.add_argument(
151+
"--iluvatar",
152+
action="store_true",
153+
help="Run iluvatar test",
154+
)
140155
parser.add_argument(
141156
"--cambricon",
142157
action="store_true",
@@ -299,6 +314,12 @@ def run(
299314
device_str = "cpu"
300315
elif args.nvidia:
301316
device_str = "cuda"
317+
elif args.metax:
318+
device_str = "cuda"
319+
elif args.moore:
320+
device_str = "musa"
321+
elif args.iluvatar:
322+
device_str = "cuda"
302323
elif args.cambricon:
303324
device_str = "mlu"
304325
else:

0 commit comments

Comments
 (0)