@@ -68,7 +68,7 @@ def decomposition(
6868 output : pd .DataFrame ,
6969 * ,
7070 sensitivity_indices : np .ndarray ,
71- dec_limit : float = 1 ,
71+ dec_limit : float | None = None ,
7272 auto_ordering : bool = True ,
7373 states : list [int ] | None = None ,
7474 statistic : Literal ["mean" , "median" ] | None = "mean" ,
@@ -79,7 +79,7 @@ def decomposition(
7979 ----------
8080 inputs : DataFrame of shape (n_runs, n_factors)
8181 Input variables.
82- output : DataFrame of shape (n_runs, 1)
82+ output : DataFrame of shape (n_runs, 1) or (n_runs,)
8383 Target variable.
8484 sensitivity_indices : ndarray of shape (n_factors, 1)
8585 Sensitivity indices, combined effect of each input.
@@ -116,7 +116,7 @@ def decomposition(
116116 inputs [cat_col ] = codes
117117
118118 inputs = inputs .to_numpy ()
119- output = output .to_numpy ()
119+ output = output .to_numpy (). flatten ()
120120
121121 # 1. variables for decomposition
122122 var_order = np .argsort (sensitivity_indices )[::- 1 ]
@@ -125,26 +125,41 @@ def decomposition(
125125 sensitivity_indices = sensitivity_indices [var_order ]
126126
127127 if auto_ordering :
128- n_var_dec = np .where (np .cumsum (sensitivity_indices ) < dec_limit )[0 ].size
128+ # handle edge case where sensitivity indices don't sum exactly to 1.0
129+ if dec_limit is None :
130+ dec_limit = 0.8 * np .sum (sensitivity_indices )
131+
132+ cumulative_sum = np .cumsum (sensitivity_indices )
133+ indices_over_limit = np .where (cumulative_sum >= dec_limit )[0 ]
134+
135+ if indices_over_limit .size > 0 :
136+ n_var_dec = indices_over_limit [0 ] + 1
137+ else :
138+ n_var_dec = sensitivity_indices .size
139+
129140 n_var_dec = max (1 , n_var_dec ) # keep at least one variable
130141 n_var_dec = min (5 , n_var_dec ) # use at most 5 variables
131142 else :
132143 n_var_dec = inputs .shape [1 ]
133144
134- # 2. states formation
145+ # 2. variable selection and reordering
146+ if auto_ordering :
147+ var_names = var_names [var_order [:n_var_dec ]].tolist ()
148+ inputs = inputs [:, var_order [:n_var_dec ]]
149+ else :
150+ var_names = var_names [:n_var_dec ].tolist ()
151+ inputs = inputs [:, :n_var_dec ]
152+
153+ # 3. states formation (after reordering/selection)
135154 if states is None :
136- states = 3 if n_var_dec < 3 else 2
155+ states = 3 if n_var_dec <= 2 else 2
137156 states = [states ] * n_var_dec
138157
139158 for i in range (n_var_dec ):
140159 n_unique = np .unique (inputs [:, i ]).size
141160 states [i ] = n_unique if n_unique <= 5 else states [i ]
142161
143- if auto_ordering :
144- var_names = var_names [var_order [:n_var_dec ]].tolist ()
145- inputs = inputs [:, var_order [:n_var_dec ]]
146-
147- # 3. decomposition
162+ # 4. decomposition
148163 bins = []
149164
150165 statistic_methods = {
@@ -153,8 +168,8 @@ def decomposition(
153168 }
154169 try :
155170 statistic_method = statistic_methods [statistic ]
156- except IndexError :
157- msg = f"'statistic' must be one of { statistic_methods .values ()} "
171+ except KeyError :
172+ msg = f"'statistic' must be one of { statistic_methods .keys ()} "
158173 raise ValueError (msg )
159174
160175 def statistic_ (inputs ):
0 commit comments