diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index dc08bee65d3d9..28f9b40ac095d 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -292,7 +292,7 @@ def _fit( ) # Determine output settings - n_samples, self.n_features_in_ = X.shape + n_samples, self.n_features_in_ = X.shape[0], X.shape[1] # Do preprocessing if 'y' is passed is_classification = False