Commit de3e6b9

Merge pull request #187 from InfiniTensor/issue/186
issue/186: support longrope
2 parents (c1a3ab2 + fc454c7), commit de3e6b9

4 files changed: 137 additions and 65 deletions

File tree

csrc/models/llama/llama_config.hpp
csrc/models/llama/llama_model.cpp
csrc/pybind11/models/llama.hpp
test/bench/test_benchmark.py

csrc/models/llama/llama_config.hpp

Lines changed: 25 additions & 21 deletions
@@ -7,6 +7,8 @@
 
 #include "../infinilm_model.hpp"
 
+#include <infinicore/nn/rope.hpp>
+
 namespace infinilm::models::llama {
 
 /**

[The -/+ pairs in the next hunk are whitespace-only re-alignment of trailing comments; the merged result is shown, with only the new lines marked +.]

@@ -20,41 +22,43 @@ struct LlamaConfig : public InfinilmModel::Config {
     infinicore::DataType dtype = infinicore::DataType::F32;
 
     // Vocabulary and embedding
     size_t vocab_size = 32000;        // Vocabulary size
     size_t hidden_size = 4096;        // Hidden dimension size
     size_t intermediate_size = 11008; // MLP intermediate dimension
 
     // Architecture
     size_t num_hidden_layers = 32;   // Number of decoder layers
     size_t num_attention_heads = 32; // Number of attention heads
     size_t num_key_value_heads = 32; // Number of key-value heads (for GQA)
     size_t head_dim = 128;           // Attention head dimension (hidden_size / num_attention_heads)
 
     // Position embeddings
     size_t max_position_embeddings = 2048; // Maximum sequence length
     double rope_theta = 10000.0;           // RoPE base frequency
+
+    std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> rope_scaling = nullptr; // RoPE scaling type
 
     // Normalization
     double rms_norm_eps = 1e-6; // RMSNorm epsilon
 
     // Activation
     std::string hidden_act = "silu";  // Activation function (typically "silu")
     std::string model_type = "llama"; // Model type identifier (matches HF configs)
 
     // Optional features
     bool use_cache = true;              // Whether to use KV cache
     bool attention_bias = true;         // Whether to use bias in Q/K/V projections (default true for 9G7B compatibility)
     bool attention_output_bias = false; // Whether to use bias in output projection (o_proj)
     bool mlp_bias = false;              // Whether to use bias in MLP projections
     bool tie_word_embeddings = false;   // Whether to tie input/output embeddings
 
     // Training/initialization parameters
     double attention_dropout = 0.0;  // Dropout ratio for attention probabilities
     double initializer_range = 0.02; // Standard deviation for weight initialization
     size_t pretraining_tp = 1;       // Tensor parallelism rank used during pretraining
 
     // Model metadata
     std::string name_or_path = ""; // Model name or path identifier
 
     // Token IDs
     int64_t pad_token_id = -1; // Padding token ID (optional)
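
For orientation, the new rope_scaling field is meant to hold the HF-style rope_scaling entry from a model's config.json. A minimal sketch of such an entry (values are illustrative, not taken from any real checkpoint; the factor lists typically carry head_dim / 2 entries, one per rotary frequency):

# Hypothetical LongRoPE entry, Hugging Face config.json style.
rope_scaling = {
    "rope_type": "longrope",
    "factor": 4.0,                              # overall extension ratio
    "original_max_position_embeddings": 4096,   # pre-extension context window
    "short_factor": [1.0] * 64,                 # per-dimension scale, short sequences
    "long_factor": [4.0] * 64,                  # per-dimension scale, long sequences
}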

csrc/models/llama/llama_model.cpp

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
     // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
     INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings,
                               config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX,
-                              dtype, device);
+                              dtype, device, config.rope_scaling);
 
     for (auto &layer : layers_) {
         if (layer) {
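
For intuition only: LongRoPE-style scaling rescales each RoPE inverse frequency by a per-dimension factor, switching between the short and long factor lists depending on whether the sequence exceeds the original context window. A rough sketch of that idea (this is not the infinicore implementation, which lives behind infinicore/nn/rope.hpp):

def longrope_inv_freq(head_dim, theta, seq_len,
                      original_max_position_embeddings,
                      short_factor, long_factor):
    # Pick the long factors once the context exceeds the original window,
    # then divide each base frequency by its per-dimension factor.
    factors = (long_factor if seq_len > original_max_position_embeddings
               else short_factor)
    return [1.0 / (factors[i] * theta ** (2 * i / head_dim))
            for i in range(head_dim // 2)]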

csrc/pybind11/models/llama.hpp

Lines changed: 86 additions & 2 deletions
@@ -6,6 +6,7 @@
 #include "../../models/llama/llama_attention.hpp"
 #include "infinicore/device.hpp"
 #include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
@@ -69,7 +70,8 @@ inline void bind_llama(py::module &m) {
         .def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp)
         .def_readwrite("name_or_path", &LlamaConfig::name_or_path)
         .def_readwrite("pad_token_id", &LlamaConfig::pad_token_id)
-        .def_property("bos_token_id", [](const LlamaConfig &self) {
+        .def_property(
+            "bos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
             return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
@@ -80,7 +82,8 @@ inline void bind_llama(py::module &m) {
             } else {
                 throw py::type_error("bos_token_id must be int or list of ints");
             } })
-        .def_property("eos_token_id", [](const LlamaConfig &self) {
+        .def_property(
+            "eos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
             return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
@@ -91,6 +94,86 @@ inline void bind_llama(py::module &m) {
             } else {
                 throw py::type_error("eos_token_id must be int or list of ints");
             } })
+        .def_property(
+            "rope_scaling",
+
+            // ---------- getter ----------
+            [](const LlamaConfig &self) -> py::object {
+                if (!self.rope_scaling) {
+                    return py::none();
+                }
+
+                using ScalingConfig = infinicore::nn::RoPE::ScalingConfig;
+                using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig;
+
+                py::dict d;
+
+                if (auto *lr = dynamic_cast<const LongRopeConfig *>(self.rope_scaling.get())) {
+                    d["type"] = "longrope";
+                    d["rope_type"] = "longrope";
+                    d["factor"] = lr->factor();
+                    d["original_max_position_embeddings"] = lr->original_max_position_embeddings();
+                    d["short_factor"] = lr->short_factor();
+                    d["long_factor"] = lr->long_factor();
+                } else {
+                    throw std::runtime_error("Unknown RoPE scaling type");
+                }
+
+                return std::move(d);
+            },
+
+            // ---------- setter ----------
+            [](LlamaConfig &self, py::object value) {
+                if (value.is_none()) {
+                    self.rope_scaling.reset();
+                    return;
+                }
+
+                if (!py::isinstance<py::dict>(value)) {
+                    throw py::type_error("rope_scaling must be a dict or None");
+                }
+
+                py::dict d = value.cast<py::dict>();
+
+                auto get_str = [&](const char *k) {
+                    if (!d.contains(k)) {
+                        throw py::key_error(k);
+                    }
+                    return py::cast<std::string>(d[k]);
+                };
+
+                std::string type = d.contains("rope_type")
+                                       ? py::cast<std::string>(d["rope_type"])
+                                       : get_str("type");
+
+                if (type == "longrope") {
+                    using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig;
+
+                    if (!d.contains("short_factor") || !d.contains("long_factor") || !d.contains("original_max_position_embeddings")) {
+                        throw py::value_error(
+                            "longrope requires short_factor, long_factor, "
+                            "original_max_position_embeddings");
+                    }
+
+                    std::vector<float> short_factor = py::cast<std::vector<float>>(d["short_factor"]);
+                    std::vector<float> long_factor = py::cast<std::vector<float>>(d["long_factor"]);
+
+                    size_t original_max_position_embeddings = py::cast<size_t>(d["original_max_position_embeddings"]);
+
+                    float factor = 1.0f;
+                    if (d.contains("factor")) {
+                        factor = py::cast<float>(d["factor"]);
+                    }
+
+                    self.rope_scaling = std::make_shared<LongRopeConfig>(
+                        std::move(short_factor),
+                        std::move(long_factor),
+                        original_max_position_embeddings,
+                        factor);
+                } else {
+                    throw py::value_error("Unsupported rope_scaling type: " + type);
+                }
+            })
         .def("validate", &LlamaConfig::validate)
         .def("kv_dim", &LlamaConfig::kv_dim)
         // Add __dir__ to make attributes discoverable via dir() in Python
@@ -108,6 +191,7 @@ inline void bind_llama(py::module &m) {
         dir_list.append("hidden_act");
         dir_list.append("model_type");
         dir_list.append("rope_theta");
+        dir_list.append("rope_scaling");
         dir_list.append("attention_bias");
         dir_list.append("attention_output_bias");
         dir_list.append("mlp_bias");

test/bench/test_benchmark.py

Lines changed: 25 additions & 41 deletions
@@ -368,7 +368,7 @@ def render_ceval(_tokenizer, conversation):
 def render_mmlu(_tokenizer, question, choices):
     """Render MMLU question and choices to input content"""
     choices_text = "\n".join(
-        [f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)]
+        [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
     )
     instruction = (
         "You are a multiple-choice question solver. "
@@ -924,7 +924,9 @@ def _load_mmlu_subject(subj):
         splits_to_load = (
             ["test"]
             if split == "test"
-            else ["validation"] if split == "val" else ["validation", "test"]
+            else ["validation"]
+            if split == "val"
+            else ["validation", "test"]
         )
         # Load each subject individually from hardcoded list, excluding "all"
         for subject_name in mmlu_subjects:
@@ -946,7 +948,9 @@ def _load_mmlu_subject(subj):
         splits_to_load = (
             ["test"]
             if split == "test"
-            else ["validation"] if split == "val" else ["validation", "test"]
+            else ["validation"]
+            if split == "val"
+            else ["validation", "test"]
         )
         records = []
         for sp in splits_to_load:
@@ -980,14 +984,13 @@ def load_subject_samples(subj_name):
     all_results = []
 
     for subj in subject_list:
-        print(f"\n{'='*60}")
+        print(f"\n{'=' * 60}")
         print(f"Evaluating subject: {subj}")
-        print(f"{'='*60}\n")
+        print(f"{'=' * 60}\n")
 
         try:
             samples, actual_subj_name = load_subject_samples(subj)
             print(f"Loaded {len(samples)} samples for subject: {actual_subj_name}")
-
             # Limit number of samples if specified
             if num_samples is not None and num_samples > 0:
                 original_count = len(samples)
@@ -996,37 +999,9 @@ def load_subject_samples(subj_name):
                     f"Limited to {len(samples)} samples for validation (from {original_count} total)"
                 )
 
-            # Test with first sample if available
-            if len(samples) > 0:
-                sample = samples[0]
-                if benchmark == "ceval":
-                    input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
-                    test_conversation = [
-                        {
-                            "role": "system",
-                            "content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
-                        },
-                        {"role": "user", "content": input_content},
-                    ]
-                    test_output = model.generate(
-                        test_conversation,
-                        max_steps=max_new_tokens,
-                        topp_=1.0,
-                        topk_=1,
-                        temperature_=1.0,
-                    )
-                elif benchmark == "mmlu":
-                    question = sample["question"]
-                    choices = sample["choices"]
-                    test_output = model.generate(
-                        question,
-                        choices,
-                        max_steps=max_new_tokens,
-                        topp_=1.0,
-                        topk_=1,
-                        temperature_=1.0,
-                    )
-                print(f"\nTest output: {test_output}\n")
+            if len(samples) == 0:
+                print(f"No samples found for subject: {actual_subj_name}")
+                continue
 
             # Evaluate samples for this subject
             result = evaluate_samples(
@@ -1044,13 +1019,22 @@ def load_subject_samples(subj_name):
     model.destroy_model_instance()
 
     # Calculate overall results
+    print(f"\n{'=' * 60}")
+    print("OVERALL RESULTS")
+    print(f"{'=' * 60}")
+    if len(all_results) == 0:
+        print("No tests were run.")
+        return
+    elif len(all_results) > 1:
+        for r in all_results:
+            print(
+                f"Subject '{r['subject']}': {r['correct']}/{r['total']} = {r['accuracy']:.2%}"
+            )
     overall_correct = sum(r["correct"] for r in all_results)
     overall_total = sum(r["total"] for r in all_results)
     overall_accuracy = overall_correct / overall_total if overall_total > 0 else 0.0
 
-    print(f"\n{'='*60}")
-    print("OVERALL RESULTS")
-    print(f"{'='*60}")
+    print(f"{'=' * 60}")
     if benchmark == "ceval":
         print(
             f"Overall 成绩: {overall_correct}/{overall_total} = {overall_accuracy:.2%}"
@@ -1062,7 +1046,7 @@ def load_subject_samples(subj_name):
 
     print(f"Total Latency: {TOTAL_TIME} seconds")
     print(f"Total Tokens Processed: {TOTAL_TOKENS} tokens")
-    print(f"Overall Throughput: {TOTAL_TOKENS/TOTAL_TIME:.2f} tokens/s")
+    print(f"Overall Throughput: {TOTAL_TOKENS / TOTAL_TIME:.2f} tokens/s")
 
     # Write CSV if output path is specified
     if output_csv:
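
Aside on the splits_to_load hunks above: the reformatted chained conditional is just black's multi-line layout for the same expression; an equivalent lookup-table form (a sketch, not what the commit uses) would be:

def splits_for(split):
    # Same mapping as the chained conditional, written as a dict lookup.
    return {"test": ["test"], "val": ["validation"]}.get(
        split, ["validation", "test"]
    )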
