diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp index c19183533..fc12fe468 100644 --- a/examples/llama_qnn_aot/aot_run.cpp +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -4,8 +4,7 @@ #include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "configuration_llama3.hpp" -#include "mllm/models/llama/tokenization_tiny_llama.hpp" -#include "mllm/models/qwen3/tokenization_qwen3.hpp" +#include "mllm/models/llama/tokenization_llama.hpp" using mllm::Argparse; using namespace mllm::qnn::aot; // NOLINT @@ -16,8 +15,8 @@ MLLM_MAIN({ auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); - auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); - auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + // auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + // auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); Argparse::parse(argc, argv); @@ -37,22 +36,35 @@ MLLM_MAIN({ config.vocab_size = llama_cfg.vocab_size; config.context_len = 1024; config.ar_len = ar_len.get(); + config.type = "llama3"; // Note: Using Qwen3 tokenizer as a placeholder. // For production use, you should implement a Llama3Tokenizer or use // the appropriate tokenizer for your model. - auto tokenizer = mllm::models::llama::TinyLlamaTokenizer(tokenizer_path.get()); + auto tokenizer = mllm::models::llama::LlamaTokenizer(tokenizer_path.get()); - auto input_tensor = tokenizer.convertMessage({{ - .role = "user", - .content = "hello", - }}); + // auto input_tensor = tokenizer.convertMessage({{ + // .role = "user", + // .content = "hello", + // }}); - input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + // input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); - // DBG: - mllm::print(input_tensor["sequence"].shape()); - mllm::print(input_tensor["sequence"]); + // // DBG: + // mllm::print(input_tensor["sequence"].shape()); + // mllm::print(input_tensor["sequence"]); + + // Runner runner(config, &tokenizer); + // if (!runner.load()) { + // std::cerr << "Failed to load model\n"; + // return 1; + // } + + std::string prompt_text; + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + auto input_tensor = tokenizer.convertMessage({{.role = "user", .content = prompt_text}}); Runner runner(config, &tokenizer); if (!runner.load()) { @@ -60,8 +72,8 @@ MLLM_MAIN({ return 1; } - runner.generate( - input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + runner.generate(input_tensor["sequence"], config.context_len, + [](const std::string& token) { std::cout << token << std::flush; }); std::cout << "\n"; return 0; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp index 8943d6cec..edf7b565a 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp @@ -3,11 +3,14 @@ #pragma once +#include #include "mllm/core/DataTypes.hpp" namespace mllm::qnn::aot { struct QnnAOTConfig { + std::string type = "qwen3"; + int num_layers = 28; int num_heads = 12; int head_dim = 128; diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index ae1fafa29..8d10e98e7 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -46,8 +46,22 @@ bool Runner::load() { // init token generator(decode) // TODO: EOS IDs auto eos_ids = std::make_unique>(); - eos_ids->insert(151643); - eos_ids->insert(151645); + // eos_ids->insert(151643); + // eos_ids->insert(151645); + + // Dynamically determine the currently loaded model based on the model name. + if (config_.type == "llama3") { + eos_ids->insert(128001); // <|end_of_text|> + eos_ids->insert(128008); // <|eom_id|> + eos_ids->insert(128009); // <|eot_id|> + } else if (config_.type == "qwen2") { + eos_ids->insert(151643); + eos_ids->insert(151645); + } else { + // qwen3 + eos_ids->insert(151643); + eos_ids->insert(151645); + } token_generator_ = std::make_unique>(tokenizer_, kv_manager_.get(), std::move(eos_ids), config_); diff --git a/mllm/models/llama/tokenization_llama.hpp b/mllm/models/llama/tokenization_llama.hpp new file mode 100644 index 000000000..9a0bc46a1 --- /dev/null +++ b/mllm/models/llama/tokenization_llama.hpp @@ -0,0 +1,233 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::llama { + +inline bool llama3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d", + L"'S", L"'T", L"'RE", L"'VE", L"'M", L"'LL", L"'D"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + matched.clear(); + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + } + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } + pos = original_pos; + } + + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t start = pos; + if (str[pos] == L' ') { ++pos; } + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline void llama3Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (llama3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } +} + +struct LlamaMessage { + std::string role; + std::string content; +}; + +class LlamaTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit LlamaTokenizer(const std::string& file_path, bool add_bos = true) : add_bos_(add_bos) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + + bpe_.initFromSentencePieceJson(file_path); + + special_tokens_trie_.add(L"<|begin_of_text|>"); + special_tokens_trie_.add(L"<|end_of_text|>"); + special_tokens_trie_.add(L"<|start_header_id|>"); + special_tokens_trie_.add(L"<|end_header_id|>"); + special_tokens_trie_.add(L"<|eot_id|>"); + } + + std::string getSystemPromptPrefix() { + std::time_t t = std::time(nullptr); + std::tm tm_ = *std::localtime(&t); + std::ostringstream oss; + oss << std::put_time(&tm_, "%d %b %Y"); + return "Cutting Knowledge Date: December 2023\nToday Date: " + oss.str() + "\n\n"; + } + + inline std::string applyChatTemplate(const std::vector& messages, bool add_generation_prompt = true) { + std::string result = ""; + if (add_bos_) result += "<|begin_of_text|>"; + for (const auto& msg : messages) { + std::string content = msg.content; + if (msg.role == "system") content = getSystemPromptPrefix() + content; + result += "<|start_header_id|>" + msg.role + "<|end_header_id|>\n\n" + content + "<|eot_id|>"; + } + if (add_generation_prompt) result += "<|start_header_id|>assistant<|end_header_id|>\n\n"; + return result; + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + llama3Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + auto bpe_ts = bpe_._bpe(mapped_str); + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + return ret; + } + + std::vector tokenize(const std::string& str) override { + std::string processed_str = str; + bool text_has_bos = (processed_str.find("<|begin_of_text|>") == 0); + if (add_bos_ && !text_has_bos) { processed_str = "<|begin_of_text|>" + processed_str; } + + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(processed_str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { + if (bytes_2_unicode_dict_inverse_.count(c)) { + utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); + } else { + return str; + } + } + return mllm::preprocessor::utf8string2WideString(utf_8_str); + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = + Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU).setMemType(kExtraInput).setName("llama-tokenizer-i0").alloc(); + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } + + std::vector encode(const std::string& str) { + auto sub_tokens = tokenize(str); + std::vector ret; + for (auto& token : sub_tokens) { ret.emplace_back(bpe_._lookup_vocab(token)); } + return ret; + } + + std::string decode(const std::vector& ids) { + std::string ret; + for (auto& each_id : ids) { + auto wstr = detokenize(each_id); + ret += mllm::preprocessor::wideString2Utf8String(wstr); + } + return ret; + } + + ARGenerationOutputPast convertMessage(const std::vector& messages) { + auto applied_string = applyChatTemplate(messages, true); + auto sequence_str = tokenize(applied_string); + std::vector ids; + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = + Tensor::empty({1, (int32_t)ids.size()}, kInt64, kCPU).setMemType(kNormal).setName("llama-tokenizer-i0").alloc(); + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + bool add_bos_ = true; + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::llama \ No newline at end of file diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 6b65f34b9..73e5aaffc 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -20,9 +20,20 @@ from typing import Callable, Optional, Union import torch +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) + +# Replace linear, rms_norm with: +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from torch import nn from torch.nn import functional as F - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin @@ -40,6 +51,7 @@ ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.processing_utils import Unpack from transformers.utils import ( TransformersKwargs, @@ -49,20 +61,6 @@ ) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import check_model_inputs -from transformers.models.llama.configuration_llama import LlamaConfig - -# Replace linear, rms_norm with: -from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( - QLinearLPBQ, -) -from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( - ActivationQDQ, - FixedActivationQDQ, -) -from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver - logger = logging.get_logger(__name__) @@ -302,8 +300,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) @@ -336,13 +334,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -384,12 +382,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -399,7 +397,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -411,7 +409,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, ) diff --git a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 45af95f8f..38ca7df62 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -13,7 +13,9 @@ QLinearW8A16_PerChannelSym, ) from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import ( + LlamaForCausalLM, +) from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver @@ -194,6 +196,7 @@ def convert_weight(m): if isinstance(m, QEmbedding): m.convert_to_deploy() + def _check_datasets_compatibility(): try: ds_ver = version("datasets") @@ -210,6 +213,7 @@ def _check_datasets_compatibility(): "datasets version. Please use datasets==2.21.0." ) + class LlamaQuantizer: def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -251,6 +255,12 @@ def compile(self): print("Compile done.") def infer(self, prompt: str): + messages = [{"role": "user", "content": prompt}] + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) # Llama models typically don't use chat templates, so we tokenize directly model_inputs = self.tokenizer([prompt], return_tensors="pt").to( self.model.device @@ -292,6 +302,7 @@ def calibrate(self, num_samples=64, max_seq_length=512): # Use streaming=True to download and process on the fly, without downloading the full几十G dataset _check_datasets_compatibility() from modelscope.msdatasets import MsDataset + dataset = MsDataset.load( "modelscope/wikitext", subset_name="wikitext-103-v1", @@ -309,11 +320,12 @@ def calibrate(self, num_samples=64, max_seq_length=512): if samples_processed >= num_samples: break - if len(entry["text"].strip()) < 1024: + text = entry["text"].strip() + if len(text) < 50: continue # Llama models typically don't use chat templates - text = entry["text"] + # text = entry["text"] model_inputs = self.tokenizer( [text], return_tensors="pt", @@ -322,16 +334,18 @@ def calibrate(self, num_samples=64, max_seq_length=512): padding=False, ).to(self.model.device) - # Only need Prefill stage: directly call forward - # This will trigger observer update statistics in ActivationQDQ - self.model.generate( - **model_inputs, - max_new_tokens=1, - do_sample=False, - temperature=None, - top_p=None, - top_k=None, - ) + self.model(**model_inputs) + + # # Only need Prefill stage: directly call forward + # # This will trigger observer update statistics in ActivationQDQ + # self.model.generate( + # **model_inputs, + # max_new_tokens=1, + # do_sample=False, + # temperature=None, + # top_p=None, + # top_k=None, + # ) samples_processed += 1 pbar.update(1) diff --git a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index a43d8b7ea..f8ad9ec56 100644 --- a/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -186,12 +186,12 @@ def __init__(self, config: Qwen2Config, layer_idx: int): # QDQ self.q_proj_input_qdq = ActivationQDQ(bits=16) - self.k_proj_input_qdq = ActivationQDQ(bits=16) + # self.k_proj_input_qdq = ActivationQDQ(bits=16) self.q_proj_output_qdq = ActivationQDQ(bits=16) self.k_proj_output_qdq = ActivationQDQ(bits=16) - self.v_proj_input_qdq = ActivationQDQ(bits=16) + # self.v_proj_input_qdq = ActivationQDQ(bits=16) self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) @@ -220,13 +220,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): ) self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) self.k_rope_concat_observer.add_observer( - self.k_proj_input_qdq.fake_quant.activation_post_process + self.k_proj_output_qdq.fake_quant.activation_post_process ) self.k_rope_concat_observer.add_observer( self.k_rope_neg_half_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( - self.q_proj_input_qdq.fake_quant.activation_post_process + self.q_proj_output_qdq.fake_quant.activation_post_process ) self.q_rope_concat_observer.add_observer( self.q_rope_neg_half_qdq.fake_quant.activation_post_process @@ -268,12 +268,12 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj_output_qdq(query_states) - hidden_states_k = self.k_proj_input_qdq(hidden_states) - key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + # hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj_output_qdq(key_states) - hidden_states_v = self.v_proj_input_qdq(hidden_states) - value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + # hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings cos = cos.unsqueeze(1) @@ -283,7 +283,7 @@ def forward( + self.q_rope_mul_1_output_qdq( rotate_half( query_states, - self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_proj_output_qdq.fake_quant.activation_post_process, self.q_rope_neg_half_qdq, self.q_rope_concat_observer, ) @@ -295,7 +295,7 @@ def forward( + self.k_rope_mul_1_output_qdq( rotate_half( key_states, - self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_proj_output_qdq.fake_quant.activation_post_process, self.k_rope_neg_half_qdq, self.k_rope_concat_observer, )