Skip to content

Commit 0e8ea33

Browse files
committed
Initial commit
1 parent 6d977d8 commit 0e8ea33

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,16 @@ def get_subset_metrics_of_estimator(
134134
and isinstance(iterations[0], Union[Numeric, NumpyNumeric].__args__)
135135
):
136136
metrics.update({"iterations": int(iterations[0])})
137+
if hasattr(estimator_instance, "estimators_"):
138+
estimators_with_trees = [
139+
t
140+
for t in estimator_instance.estimators_
141+
if hasattr(t, "tree_") and hasattr(t.tree_, "node_count")
142+
]
143+
if estimators_with_trees:
144+
metrics["n_nodes"] = sum(
145+
t.tree_.node_count for t in estimators_with_trees
146+
)
137147
if task == "classification":
138148
y_pred = convert_to_numpy(estimator_instance.predict(x))
139149
metrics.update(

sklbench/report/implementation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
# NB: 'n_clusters' is parameter of KMeans while
7272
# 'clusters' is number of computer clusters by DBSCAN
7373
"clusters",
74+
# tree ensembles
75+
"n_nodes",
7476
],
7577
"incomparable": [
7678
"1st-mean run ratio",

0 commit comments

Comments
 (0)