@@ -6,6 +6,7 @@
 #include "../../models/llama/llama_attention.hpp"
 #include "infinicore/device.hpp"
 #include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
@@ -69,7 +70,8 @@ inline void bind_llama(py::module &m) {
         .def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp)
         .def_readwrite("name_or_path", &LlamaConfig::name_or_path)
         .def_readwrite("pad_token_id", &LlamaConfig::pad_token_id)
-        .def_property("bos_token_id", [](const LlamaConfig &self) {
+        .def_property(
+            "bos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
             return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
@@ -80,7 +82,8 @@ inline void bind_llama(py::module &m) {
             } else {
                 throw py::type_error("bos_token_id must be int or list of ints");
             } })
-        .def_property("eos_token_id", [](const LlamaConfig &self) {
+        .def_property(
+            "eos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
             return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
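
Both token-id properties share one pattern: the setter accepts either a single int or a list of ints, and the getter always returns a list to match the Python config format. A minimal sketch of the resulting behavior, assuming the bindings are importable as an infinicore extension module (a hypothetical import path; the real one depends on where bind_llama() is registered):

    # Hypothetical import path; depends on how bind_llama() is registered.
    from infinicore import llama

    cfg = llama.LlamaConfig()

    cfg.bos_token_id = 1            # a scalar is accepted...
    cfg.eos_token_id = [2, 32000]   # ...and so is a list

    assert cfg.bos_token_id == [1]  # reads are always normalized to lists
    assert cfg.eos_token_id == [2, 32000]

    try:
        cfg.eos_token_id = "</s>"   # anything else raises TypeError
    except TypeError:
        pass
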
@@ -91,6 +94,86 @@ inline void bind_llama(py::module &m) {
             } else {
                 throw py::type_error("eos_token_id must be int or list of ints");
             } })
+        .def_property(
+            "rope_scaling",
+
+            // ---------- getter ----------
+            [](const LlamaConfig &self) -> py::object {
+                if (!self.rope_scaling) {
+                    return py::none();
+                }
+
+                using ScalingConfig = infinicore::nn::RoPE::ScalingConfig;
+                using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig;
+
+                py::dict d;
+
+                if (auto *lr = dynamic_cast<const LongRopeConfig *>(self.rope_scaling.get())) {
+                    d["type"] = "longrope";
+                    d["rope_type"] = "longrope";
+                    d["factor"] = lr->factor();
+                    d["original_max_position_embeddings"] = lr->original_max_position_embeddings();
+                    d["short_factor"] = lr->short_factor();
+                    d["long_factor"] = lr->long_factor();
+                } else {
+                    throw std::runtime_error("Unknown RoPE scaling type");
+                }
+
+                return std::move(d);
+            },
+
+            // ---------- setter ----------
+            [](LlamaConfig &self, py::object value) {
+                if (value.is_none()) {
+                    self.rope_scaling.reset();
+                    return;
+                }
+
+                if (!py::isinstance<py::dict>(value)) {
+                    throw py::type_error("rope_scaling must be a dict or None");
+                }
+
+                py::dict d = value.cast<py::dict>();
+
+                auto get_str = [&](const char *k) {
+                    if (!d.contains(k)) {
+                        throw py::key_error(k);
+                    }
+                    return py::cast<std::string>(d[k]);
+                };
+
+                std::string type = d.contains("rope_type")
+                                       ? py::cast<std::string>(d["rope_type"])
+                                       : get_str("type");
+
+                if (type == "longrope") {
+                    using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig;
+
+                    if (!d.contains("short_factor") || !d.contains("long_factor") || !d.contains("original_max_position_embeddings")) {
+                        throw py::value_error(
+                            "longrope requires short_factor, long_factor, "
+                            "original_max_position_embeddings");
+                    }
+
+                    std::vector<float> short_factor = py::cast<std::vector<float>>(d["short_factor"]);
+                    std::vector<float> long_factor = py::cast<std::vector<float>>(d["long_factor"]);
+
+                    size_t original_max_position_embeddings = py::cast<size_t>(d["original_max_position_embeddings"]);
+
+                    float factor = 1.0f;
+                    if (d.contains("factor")) {
+                        factor = py::cast<float>(d["factor"]);
+                    }
+
+                    self.rope_scaling = std::make_shared<LongRopeConfig>(
+                        std::move(short_factor),
+                        std::move(long_factor),
+                        original_max_position_embeddings,
+                        factor);
+                } else {
+                    throw py::value_error("Unsupported rope_scaling type: " + type);
+                }
+            })
         .def("validate", &LlamaConfig::validate)
         .def("kv_dim", &LlamaConfig::kv_dim)
         // Add __dir__ to make attributes discoverable via dir() in Python
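
The new rope_scaling property round-trips HuggingFace-style scaling dicts: the setter prefers the "rope_type" key over the legacy "type" key, currently accepts only "longrope", validates the required fields, and defaults factor to 1.0; assigning None clears the stored config. A sketch of the intended usage, continuing the hypothetical import path above:

    cfg = llama.LlamaConfig()

    # HF-style dict; "rope_type" takes precedence if "type" is also present.
    cfg.rope_scaling = {
        "rope_type": "longrope",
        "short_factor": [1.0] * 64,   # example data; real lengths depend on the model
        "long_factor": [4.0] * 64,
        "original_max_position_embeddings": 4096,
        "factor": 8.0,                # optional, defaults to 1.0
    }

    # The getter rebuilds a dict with both "type" and "rope_type" filled in.
    s = cfg.rope_scaling
    assert s["type"] == s["rope_type"] == "longrope"
    assert s["original_max_position_embeddings"] == 4096

    # None resets the scaling; an unsupported type raises ValueError.
    cfg.rope_scaling = None
    assert cfg.rope_scaling is None
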
@@ -108,6 +191,7 @@ inline void bind_llama(py::module &m) {
             dir_list.append("hidden_act");
             dir_list.append("model_type");
             dir_list.append("rope_theta");
+            dir_list.append("rope_scaling");
             dir_list.append("attention_bias");
             dir_list.append("attention_output_bias");
             dir_list.append("mlp_bias");
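
Since the custom __dir__ hook now lists rope_scaling as well, the new attribute is discoverable the same way as the existing fields:

    # rope_scaling shows up in dir() next to rope_theta, mlp_bias, etc.
    assert "rope_scaling" in dir(llama.LlamaConfig())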