From 95d86fc7a3e27efd5ad52a6bffca7c9fa95e2ea1 Mon Sep 17 00:00:00 2001 From: Jinghe Zhang <1132764130@qq.com> Date: Tue, 19 May 2026 19:50:33 +0800 Subject: [PATCH 1/2] Modify llama tokenizer and pymllm. --- examples/llama_qnn_aot/aot_run.cpp | 43 +-- mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp | 3 + mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp | 18 +- mllm/models/llama/tokenization_llama.hpp | 245 ++++++++++++++++++ .../transformers/llama/modeling_llama.py | 20 +- .../qualcomm/transformers/llama/runner.py | 35 ++- .../transformers/qwen2/modeling_qwen2.py | 20 +- 7 files changed, 334 insertions(+), 50 deletions(-) create mode 100644 mllm/models/llama/tokenization_llama.hpp diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp index c19183533..4847954e0 100644 --- a/examples/llama_qnn_aot/aot_run.cpp +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -4,8 +4,7 @@ #include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "configuration_llama3.hpp" -#include "mllm/models/llama/tokenization_tiny_llama.hpp" -#include "mllm/models/qwen3/tokenization_qwen3.hpp" +#include "mllm/models/llama/tokenization_llama.hpp" using mllm::Argparse; using namespace mllm::qnn::aot; // NOLINT @@ -16,8 +15,8 @@ MLLM_MAIN({ auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); - auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); - auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + // auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + // auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); Argparse::parse(argc, argv); @@ -37,22 +36,36 @@ MLLM_MAIN({ config.vocab_size = llama_cfg.vocab_size; config.context_len = 1024; config.ar_len = ar_len.get(); + config.type = "llama3"; // Note: Using Qwen3 tokenizer as a placeholder. // For production use, you should implement a Llama3Tokenizer or use // the appropriate tokenizer for your model. - auto tokenizer = mllm::models::llama::TinyLlamaTokenizer(tokenizer_path.get()); + auto tokenizer = mllm::models::llama::LlamaTokenizer(tokenizer_path.get()); - auto input_tensor = tokenizer.convertMessage({{ - .role = "user", - .content = "hello", - }}); + // auto input_tensor = tokenizer.convertMessage({{ + // .role = "user", + // .content = "hello", + // }}); - input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + // input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); - // DBG: - mllm::print(input_tensor["sequence"].shape()); - mllm::print(input_tensor["sequence"]); + // // DBG: + // mllm::print(input_tensor["sequence"].shape()); + // mllm::print(input_tensor["sequence"]); + + // Runner runner(config, &tokenizer); + // if (!runner.load()) { + // std::cerr << "Failed to load model\n"; + // return 1; + // } + + + std::string prompt_text; + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + auto input_tensor = tokenizer.convertMessage({{.role = "user", .content = prompt_text}}); Runner runner(config, &tokenizer); if (!runner.load()) { @@ -60,8 +73,8 @@ MLLM_MAIN({ return 1; } - runner.generate( - input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + runner.generate(input_tensor["sequence"], config.context_len, + [](const std::string& token) { std::cout << token << std::flush; }); std::cout << "\n"; return 0; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp index 8943d6cec..3486cc3f9 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp @@ -3,11 +3,14 @@ #pragma once +#include #include "mllm/core/DataTypes.hpp" namespace mllm::qnn::aot { struct QnnAOTConfig { + std::string type = "qwen3"; + int num_layers = 28; int num_heads = 12; int head_dim = 128; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index ae1fafa29..68d002c67 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -46,8 +46,22 @@ bool Runner::load() { // init token generator(decode) // TODO: EOS IDs auto eos_ids = std::make_unique>(); - eos_ids->insert(151643); - eos_ids->insert(151645); + // eos_ids->insert(151643); + // eos_ids->insert(151645); + + // Dynamically determine the currently loaded model based on the model name. + if (config_.type == "llama3") { + eos_ids->insert(128001); // <|end_of_text|> + eos_ids->insert(128008); // <|eom_id|> + eos_ids->insert(128009); // <|eot_id|> + } else if (config_.type == "qwen2"){ + eos_ids->insert(151643); + eos_ids->insert(151645); + } else{ + // qwen3 + eos_ids->insert(151643); + eos_ids->insert(151645); + } token_generator_ = std::make_unique>(tokenizer_, kv_manager_.get(), std::move(eos_ids), config_); diff --git a/mllm/models/llama/tokenization_llama.hpp b/mllm/models/llama/tokenization_llama.hpp new file mode 100644 index 000000000..ad5f2ca15 --- /dev/null +++ b/mllm/models/llama/tokenization_llama.hpp @@ -0,0 +1,245 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::llama { + +// 适配 Llama 3 的正则切分逻辑 +inline bool llama3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + // 1. 匹配缩写 + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d", L"'S", L"'T", L"'RE", L"'VE", L"'M", L"'LL", L"'D"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + // 2. 匹配字母 + { + size_t original_pos = pos; + matched.clear(); + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + } + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } + pos = original_pos; + } + + // 3. 匹配数字 + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + // 4. 匹配符号 + { + size_t start = pos; + if (str[pos] == L' ') { ++pos; } + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { ++pos; } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])); + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } + pos = start; + } + + // 5. 匹配空格 + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline void llama3Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (llama3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } +} + +struct LlamaMessage { + std::string role; + std::string content; +}; + +// 恢复继承自原版的 AutoTokenizer,满足 aot_run.cpp 的要求 +class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit LlamaTokenizer(const std::string& file_path, bool add_bos = true) : add_bos_(add_bos) { + preprocessor::initLocal(); + // 恢复内建的字典映射机制 + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + + bpe_.initFromSentencePieceJson(file_path); + + special_tokens_trie_.add(L"<|begin_of_text|>"); + special_tokens_trie_.add(L"<|end_of_text|>"); + special_tokens_trie_.add(L"<|start_header_id|>"); + special_tokens_trie_.add(L"<|end_header_id|>"); + special_tokens_trie_.add(L"<|eot_id|>"); + } + + std::string getSystemPromptPrefix() { + std::time_t t = std::time(nullptr); + std::tm tm_ = *std::localtime(&t); + std::ostringstream oss; + oss << std::put_time(&tm_, "%d %b %Y"); + return "Cutting Knowledge Date: December 2023\nToday Date: " + oss.str() + "\n\n"; + } + + inline std::string applyChatTemplate(const std::vector& messages, bool add_generation_prompt = true) { + std::string result = ""; + if (add_bos_) result += "<|begin_of_text|>"; + for (const auto& msg : messages) { + std::string content = msg.content; + if (msg.role == "system") content = getSystemPromptPrefix() + content; + result += "<|start_header_id|>" + msg.role + "<|end_header_id|>\n\n" + content + "<|eot_id|>"; + } + if (add_generation_prompt) result += "<|start_header_id|>assistant<|end_header_id|>\n\n"; + return result; + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + llama3Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + // 执行字节映射 + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + auto bpe_ts = bpe_._bpe(mapped_str); + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + return ret; + } + + std::vector tokenize(const std::string& str) override { + std::string processed_str = str; + bool text_has_bos = (processed_str.find("<|begin_of_text|>") == 0); + if (add_bos_ && !text_has_bos) { + processed_str = "<|begin_of_text|>" + processed_str; + } + + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(processed_str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { + if (bytes_2_unicode_dict_inverse_.count(c)) { + utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); + } else { + return str; + } + } + return mllm::preprocessor::utf8string2WideString(utf_8_str); + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("llama-tokenizer-i0") + .alloc(); + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } + + // 供 test_c.cpp 调用的便捷接口 + std::vector encode(const std::string& str) { + auto sub_tokens = tokenize(str); + std::vector ret; + for (auto& token : sub_tokens) { ret.emplace_back(bpe_._lookup_vocab(token)); } + return ret; + } + + std::string decode(const std::vector& ids) { + std::string ret; + for (auto& each_id : ids) { + auto wstr = detokenize(each_id); + ret += mllm::preprocessor::wideString2Utf8String(wstr); + } + return ret; + } + + ARGenerationOutputPast convertMessage(const std::vector& messages) { + auto applied_string = applyChatTemplate(messages, true); + auto sequence_str = tokenize(applied_string); + std::vector ids; + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kNormal) + .setName("llama-tokenizer-i0") + .alloc(); + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + bool add_bos_ = true; + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::llama \ No newline at end of file diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 6b65f34b9..8ebf0afcd 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -302,8 +302,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) @@ -336,13 +336,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -384,12 +384,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -399,7 +399,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -411,7 +411,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, ) diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 45af95f8f..9aa1a4f73 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -251,6 +251,12 @@ def compile(self): print("Compile done.") def infer(self, prompt: str): + messages = [{"role": "user", "content": prompt}] + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) # Llama models typically don't use chat templates, so we tokenize directly model_inputs = self.tokenizer([prompt], return_tensors="pt").to( self.model.device @@ -308,12 +314,13 @@ def calibrate(self, num_samples=64, max_seq_length=512): for entry in dataset: if samples_processed >= num_samples: break - - if len(entry["text"].strip()) < 1024: + + text = entry["text"].strip() + if len(text) < 50: continue # Llama models typically don't use chat templates - text = entry["text"] + # text = entry["text"] model_inputs = self.tokenizer( [text], return_tensors="pt", @@ -322,16 +329,18 @@ def calibrate(self, num_samples=64, max_seq_length=512): padding=False, ).to(self.model.device) - # Only need Prefill stage: directly call forward - # This will trigger observer update statistics in ActivationQDQ - self.model.generate( - **model_inputs, - max_new_tokens=1, - do_sample=False, - temperature=None, - top_p=None, - top_k=None, - ) + self.model(**model_inputs) + + # # Only need Prefill stage: directly call forward + # # This will trigger observer update statistics in ActivationQDQ + # self.model.generate( + # **model_inputs, + # max_new_tokens=1, + # do_sample=False, + # temperature=None, + # top_p=None, + # top_k=None, + # ) samples_processed += 1 pbar.update(1) diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index a43d8b7ea..f8ad9ec56 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -186,12 +186,12 @@ def __init__(self, config: Qwen2Config, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) @@ -220,13 +220,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -268,12 +268,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -283,7 +283,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -295,7 +295,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, ) From c0d946721003afb6673922326ba92bb93df00708 Mon Sep 17 00:00:00 2001 From: Jinghe Zhang <1132764130@qq.com> Date: Sat, 6 Jun 2026 21:09:36 +0800 Subject: [PATCH 2/2] Format. --- examples/llama_qnn_aot/aot_run.cpp | 1 - mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp | 2 +- mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp | 20 +- mllm/models/llama/tokenization_llama.hpp | 186 ++++++++---------- .../transformers/llama/modeling_llama.py | 28 ++- .../qualcomm/transformers/llama/runner.py | 9 +- 6 files changed, 118 insertions(+), 128 deletions(-) diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp index 4847954e0..fc12fe468 100644 --- a/examples/llama_qnn_aot/aot_run.cpp +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -60,7 +60,6 @@ MLLM_MAIN({ // return 1; // } - std::string prompt_text; fmt::print("💬 Prompt text (or 'exit/quit'): "); std::getline(std::cin, prompt_text); diff --git a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp index 3486cc3f9..edf7b565a 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp @@ -10,7 +10,7 @@ namespace mllm::qnn::aot { struct QnnAOTConfig { std::string type = "qwen3"; - + int num_layers = 28; int num_heads = 12; int head_dim = 128; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index 68d002c67..8d10e98e7 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -51,16 +51,16 @@ bool Runner::load() { // Dynamically determine the currently loaded model based on the model name. if (config_.type == "llama3") { - eos_ids->insert(128001); // <|end_of_text|> - eos_ids->insert(128008); // <|eom_id|> - eos_ids->insert(128009); // <|eot_id|> - } else if (config_.type == "qwen2"){ - eos_ids->insert(151643); - eos_ids->insert(151645); - } else{ - // qwen3 - eos_ids->insert(151643); - eos_ids->insert(151645); + eos_ids->insert(128001); // <|end_of_text|> + eos_ids->insert(128008); // <|eom_id|> + eos_ids->insert(128009); // <|eot_id|> + } else if (config_.type == "qwen2") { + eos_ids->insert(151643); + eos_ids->insert(151645); + } else { + // qwen3 + eos_ids->insert(151643); + eos_ids->insert(151645); } token_generator_ = std::make_unique>(tokenizer_, kv_manager_.get(), std::move(eos_ids), config_); diff --git a/mllm/models/llama/tokenization_llama.hpp b/mllm/models/llama/tokenization_llama.hpp index ad5f2ca15..9a0bc46a1 100644 --- a/mllm/models/llama/tokenization_llama.hpp +++ b/mllm/models/llama/tokenization_llama.hpp @@ -16,83 +16,81 @@ namespace mllm::models::llama { -// 适配 Llama 3 的正则切分逻辑 inline bool llama3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { - if (pos >= str.size()) return false; - - // 1. 匹配缩写 - static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d", L"'S", L"'T", L"'RE", L"'VE", L"'M", L"'LL", L"'D"}; - for (const auto& contraction : contractions) { - if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { - matched = contraction; - pos += contraction.size(); - return true; - } + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d", + L"'S", L"'T", L"'RE", L"'VE", L"'M", L"'LL", L"'D"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; } + } - // 2. 匹配字母 - { - size_t original_pos = pos; - matched.clear(); - if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { - matched += str[pos]; - ++pos; - } - if (pos < str.size() && preprocessor::isLetter(str[pos])) { - do { - matched += str[pos]; - ++pos; - } while (pos < str.size() && preprocessor::isLetter(str[pos])); - return true; - } - pos = original_pos; + { + size_t original_pos = pos; + matched.clear(); + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; } - - // 3. 匹配数字 - if (preprocessor::isDigit(str[pos])) { - matched = str.substr(pos, 1); + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; ++pos; - return true; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; } + pos = original_pos; + } - // 4. 匹配符号 - { - size_t start = pos; - if (str[pos] == L' ') { ++pos; } - if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { - do { ++pos; } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])); - matched = str.substr(start, pos - start); - while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { - matched += str[pos]; - ++pos; - } - return true; - } - pos = start; - } + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } - // 5. 匹配空格 - if (std::iswspace(str[pos])) { - size_t start = pos; - while (pos < str.size() && std::iswspace(str[pos])) ++pos; - matched = str.substr(start, pos - start); - return true; + { + size_t start = pos; + if (str[pos] == L' ') { ++pos; } + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } - return false; + return false; } inline void llama3Regex(const std::string& str, std::vector& splitted) { - auto w_string = preprocessor::utf8string2WideString(str); - size_t pos = 0; - while (pos < w_string.size()) { - std::wstring matched; - if (llama3TokenizerMatchPattern(w_string, pos, matched)) { - splitted.push_back(matched); - } else { - ++pos; - } + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (llama3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; } + } } struct LlamaMessage { @@ -100,15 +98,13 @@ struct LlamaMessage { std::string content; }; -// 恢复继承自原版的 AutoTokenizer,满足 aot_run.cpp 的要求 class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { public: explicit LlamaTokenizer(const std::string& file_path, bool add_bos = true) : add_bos_(add_bos) { preprocessor::initLocal(); - // 恢复内建的字典映射机制 preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } - + bpe_.initFromSentencePieceJson(file_path); special_tokens_trie_.add(L"<|begin_of_text|>"); @@ -143,12 +139,11 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { std::vector splitted; llama3Regex(str, splitted); for (const auto& s : splitted) { - auto utf_8_str = preprocessor::wideString2Utf8String(s); - std::wstring mapped_str; - // 执行字节映射 - for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } - auto bpe_ts = bpe_._bpe(mapped_str); - for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + auto bpe_ts = bpe_._bpe(mapped_str); + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } } return ret; } @@ -156,19 +151,17 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { std::vector tokenize(const std::string& str) override { std::string processed_str = str; bool text_has_bos = (processed_str.find("<|begin_of_text|>") == 0); - if (add_bos_ && !text_has_bos) { - processed_str = "<|begin_of_text|>" + processed_str; - } + if (add_bos_ && !text_has_bos) { processed_str = "<|begin_of_text|>" + processed_str; } auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(processed_str)); std::vector all_tokens; for (const auto& token : tokens) { - if (special_tokens_trie_.isSpecialToken(token)) { - all_tokens.emplace_back(token); - continue; - } - auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); - all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); } return all_tokens; } @@ -178,12 +171,12 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { std::wstring detokenize(int64_t pos_idx) override { auto str = _detokenize(pos_idx); std::string utf_8_str; - for (wchar_t c : str) { - if (bytes_2_unicode_dict_inverse_.count(c)) { - utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); - } else { - return str; - } + for (wchar_t c : str) { + if (bytes_2_unicode_dict_inverse_.count(c)) { + utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); + } else { + return str; + } } return mllm::preprocessor::utf8string2WideString(utf_8_str); } @@ -191,16 +184,13 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { Tensor convert2Ids(const std::vector& strs) override { std::vector ids; for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } - Tensor ret = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) - .setMemType(kExtraInput) - .setName("llama-tokenizer-i0") - .alloc(); + Tensor ret = + Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU).setMemType(kExtraInput).setName("llama-tokenizer-i0").alloc(); auto ptr = ret.ptr(); for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } return ret; } - // 供 test_c.cpp 调用的便捷接口 std::vector encode(const std::string& str) { auto sub_tokens = tokenize(str); std::vector ret; @@ -211,8 +201,8 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { std::string decode(const std::vector& ids) { std::string ret; for (auto& each_id : ids) { - auto wstr = detokenize(each_id); - ret += mllm::preprocessor::wideString2Utf8String(wstr); + auto wstr = detokenize(each_id); + ret += mllm::preprocessor::wideString2Utf8String(wstr); } return ret; } @@ -222,11 +212,9 @@ class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { auto sequence_str = tokenize(applied_string); std::vector ids; for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } - - Tensor sequence = Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU) - .setMemType(kNormal) - .setName("llama-tokenizer-i0") - .alloc(); + + Tensor sequence = + Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU).setMemType(kNormal).setName("llama-tokenizer-i0").alloc(); auto ptr = sequence.ptr(); for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 8ebf0afcd..73e5aaffc 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -20,9 +20,20 @@ from typing import Callable, Optional, Union import torch +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) + +# Replace linear, rms_norm with: +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from torch import nn from torch.nn import functional as F - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin @@ -40,6 +51,7 @@ ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.processing_utils import Unpack from transformers.utils import ( TransformersKwargs, @@ -49,20 +61,6 @@ ) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import check_model_inputs -from transformers.models.llama.configuration_llama import LlamaConfig - -# Replace linear, rms_norm with: -from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( - QLinearLPBQ, -) -from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( - ActivationQDQ, - FixedActivationQDQ, -) -from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver - logger = logging.get_logger(__name__) diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 9aa1a4f73..38ca7df62 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -13,7 +13,9 @@ QLinearW8A16_PerChannelSym, ) from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import ( + LlamaForCausalLM, +) from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver @@ -194,6 +196,7 @@ def convert_weight(m): if isinstance(m, QEmbedding): m.convert_to_deploy() + def _check_datasets_compatibility(): try: ds_ver = version("datasets") @@ -210,6 +213,7 @@ def _check_datasets_compatibility(): "datasets version. Please use datasets==2.21.0." ) + class LlamaQuantizer: def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -298,6 +302,7 @@ def calibrate(self, num_samples=64, max_seq_length=512): # Use streaming=True to download and process on the fly, without downloading the full几十G dataset _check_datasets_compatibility() from modelscope.msdatasets import MsDataset + dataset = MsDataset.load( "modelscope/wikitext", subset_name="wikitext-103-v1", @@ -314,7 +319,7 @@ def calibrate(self, num_samples=64, max_seq_length=512): for entry in dataset: if samples_processed >= num_samples: break - + text = entry["text"].strip() if len(text) < 50: continue