Skip to content

Commit ddc00f4

Browse files
icarosaderoCeliaBenquet
authored andcommitted
Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (AdaptiveMotorControlLab#206)
1 parent e8004ba commit ddc00f4

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

cebra/integrations/sklearn/utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,26 @@
2222
import warnings
2323

2424
import numpy.typing as npt
25+
import packaging
26+
import sklearn
2527
import sklearn.utils.validation as sklearn_utils_validation
2628
import torch
2729

2830
import cebra.helper
2931

3032

33+
def _sklearn_check_array(array, **kwargs):
34+
# NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
35+
# https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html
36+
# force_all_finite was renamed to ensure_all_finite and will be removed in 1.8.
37+
if packaging.version.parse(
38+
sklearn.__version__) < packaging.version.parse("1.6"):
39+
if "ensure_all_finite" in kwargs:
40+
kwargs["force_all_finite"] = kwargs["ensure_all_finite"]
41+
del kwargs["ensure_all_finite"]
42+
return sklearn_utils_validation.check_array(array, **kwargs)
43+
44+
3145
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
3246
"""Handle deprecated arguments of a function until they are replaced.
3347
@@ -74,16 +88,16 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7488
Returns:
7589
The converted and validated array.
7690
"""
77-
return sklearn_utils_validation.check_array(
91+
return _sklearn_check_array(
7892
X,
7993
accept_sparse=False,
8094
accept_large_sparse=False,
8195
# NOTE: remove float16 because F.pad does not allow float16.
8296
dtype=("float32", "float64"),
8397
order=None,
8498
copy=False,
85-
force_all_finite=True,
8699
ensure_2d=True,
100+
ensure_all_finite=True,
87101
allow_nd=False,
88102
ensure_min_samples=min_samples,
89103
ensure_min_features=1,
@@ -106,15 +120,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
106120
Returns:
107121
The converted and validated labels.
108122
"""
109-
return sklearn_utils_validation.check_array(
123+
return _sklearn_check_array(
110124
y,
111125
accept_sparse=False,
112126
accept_large_sparse=False,
113127
dtype="numeric",
114128
order=None,
115129
copy=False,
116-
force_all_finite=True,
117130
ensure_2d=False,
131+
ensure_all_finite=True,
118132
allow_nd=False,
119133
ensure_min_samples=min_samples,
120134
)

0 commit comments

Comments
 (0)