Skip to content

Commit e985873

Browse files
committed
fix: fix llmc weight load bug
1 parent b6aa83c commit e985873

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,7 @@ constexpr int32_t kHeaderFP32Version = 3;
249249

250250
std::unique_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
251251
std::ifstream ifs(filepath, std::ios::binary);
252-
const auto header = ReadSeveralBytesFromIfstream(256, &ifs);
252+
const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs);
253253

254254
const auto magic = BytesToType<uint32_t>(header, 0);
255255
CHECK_EQ(magic, kHeaderMagic);
@@ -271,7 +271,7 @@ std::unique_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
271271

272272
auto state_dict = gpt2->StateDict();
273273
// transformer.wte.weight
274-
// (vocab_size, n_embd) -> padded -> (padded_vocab_size, n_embd)
274+
// (padded_vocab_size, n_embd) -> un_pad -> (vocab_size, n_embd)
275275
auto &transformer_wte_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName,
276276
nn::Embedding::kParamWeightName)];
277277
ifs.read(reinterpret_cast<char *>(transformer_wte_weight->DataPtr()), transformer_wte_weight->SizeInBytes());

0 commit comments

Comments
 (0)