Skip to content

Commit e463bf1

Browse files
committed
Performance improvement for xgboost exporter
1 parent e1e0432 commit e463bf1

1 file changed

Lines changed: 4 additions & 12 deletions

File tree

nyoka/xgboost/xgboost_to_pmml.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -240,10 +240,7 @@ def get_segments_for_xgbr(model, derived_col_names, feature_names, target_name,
240240
Nyoka's Segment object
241241
242242
"""
243-
segments = list()
244-
get_nodes_in_json_format = []
245-
for i in range(model.n_estimators):
246-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
243+
get_nodes_in_json_format = model._Booster.get_dump(dump_format='json')
247244
segmentation = pml.Segmentation(multipleModelMethod=MULTIPLE_MODEL_METHOD.SUM,
248245
Segment=generate_Segments_Equal_To_Estimators(get_nodes_in_json_format, derived_col_names,
249246
feature_names))
@@ -354,7 +351,7 @@ def generate_Segments_Equal_To_Estimators(val, derived_col_names, col_names):
354351
main_node = pml.Node(True_=pml.True_())
355352
m_flds = []
356353
mining_field_for_innner_segments = col_names
357-
create_node(val[i], main_node, derived_col_names)
354+
create_node(json.loads(val[i]), main_node, derived_col_names)
358355

359356
for name in mining_field_for_innner_segments:
360357
m_flds.append(pml.MiningField(name=name))
@@ -436,9 +433,7 @@ def get_segments_for_xgbc(model, derived_col_names, feature_names, target_name,
436433
segments = list()
437434

438435
if model.n_classes_ == 2:
439-
get_nodes_in_json_format=[]
440-
for i in range(model.n_estimators):
441-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
436+
get_nodes_in_json_format=model._Booster.get_dump(dump_format='json')
442437
mining_schema_for_1st_segment = mining_Field_For_First_Segment(feature_names)
443438
outputField = list()
444439
outputField.append(pml.OutputField(name="xgbValue", optype=OPTYPE.CONTINUOUS, dataType=DATATYPE.FLOAT,
@@ -457,10 +452,7 @@ def get_segments_for_xgbc(model, derived_col_names, feature_names, target_name,
457452

458453
segments.append(last_segment)
459454
else:
460-
461-
get_nodes_in_json_format = []
462-
for i in range(model.n_estimators * model.n_classes_):
463-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
455+
get_nodes_in_json_format = model._Booster.get_dump(dump_format='json')
464456
oField = list()
465457
for index in range(0, model.n_classes_):
466458
inner_segment = []

0 commit comments

Comments (0)