Skip to content

Commit 6495b61

Browse files
committed
Fix variable selection logic and match MATLAB state formation
1 parent acc0f66 commit 6495b61

1 file changed

Lines changed: 28 additions & 13 deletions

File tree

src/simdec/decomposition.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def decomposition(
6868
output: pd.DataFrame,
6969
*,
7070
sensitivity_indices: np.ndarray,
71-
dec_limit: float = 1,
71+
dec_limit: float | None = None,
7272
auto_ordering: bool = True,
7373
states: list[int] | None = None,
7474
statistic: Literal["mean", "median"] | None = "mean",
@@ -79,7 +79,7 @@ def decomposition(
7979
----------
8080
inputs : DataFrame of shape (n_runs, n_factors)
8181
Input variables.
82-
output : DataFrame of shape (n_runs, 1)
82+
output : DataFrame of shape (n_runs, 1) or (n_runs,)
8383
Target variable.
8484
sensitivity_indices : ndarray of shape (n_factors, 1)
8585
Sensitivity indices, combined effect of each input.
@@ -116,7 +116,7 @@ def decomposition(
116116
inputs[cat_col] = codes
117117

118118
inputs = inputs.to_numpy()
119-
output = output.to_numpy()
119+
output = output.to_numpy().flatten()
120120

121121
# 1. variables for decomposition
122122
var_order = np.argsort(sensitivity_indices)[::-1]
@@ -125,26 +125,41 @@ def decomposition(
125125
sensitivity_indices = sensitivity_indices[var_order]
126126

127127
if auto_ordering:
128-
n_var_dec = np.where(np.cumsum(sensitivity_indices) < dec_limit)[0].size
128+
# handle edge case where sensitivity indices don't sum exactly to 1.0
129+
if dec_limit is None:
130+
dec_limit = 0.8 * np.sum(sensitivity_indices)
131+
132+
cumulative_sum = np.cumsum(sensitivity_indices)
133+
indices_over_limit = np.where(cumulative_sum >= dec_limit)[0]
134+
135+
if indices_over_limit.size > 0:
136+
n_var_dec = indices_over_limit[0] + 1
137+
else:
138+
n_var_dec = sensitivity_indices.size
139+
129140
n_var_dec = max(1, n_var_dec) # keep at least one variable
130141
n_var_dec = min(5, n_var_dec) # use at most 5 variables
131142
else:
132143
n_var_dec = inputs.shape[1]
133144

134-
# 2. states formation
145+
# 2. variable selection and reordering
146+
if auto_ordering:
147+
var_names = var_names[var_order[:n_var_dec]].tolist()
148+
inputs = inputs[:, var_order[:n_var_dec]]
149+
else:
150+
var_names = var_names[:n_var_dec].tolist()
151+
inputs = inputs[:, :n_var_dec]
152+
153+
# 3. states formation (after reordering/selection)
135154
if states is None:
136-
states = 3 if n_var_dec < 3 else 2
155+
states = 3 if n_var_dec <= 2 else 2
137156
states = [states] * n_var_dec
138157

139158
for i in range(n_var_dec):
140159
n_unique = np.unique(inputs[:, i]).size
141160
states[i] = n_unique if n_unique <= 5 else states[i]
142161

143-
if auto_ordering:
144-
var_names = var_names[var_order[:n_var_dec]].tolist()
145-
inputs = inputs[:, var_order[:n_var_dec]]
146-
147-
# 3. decomposition
162+
# 4. decomposition
148163
bins = []
149164

150165
statistic_methods = {
@@ -153,8 +168,8 @@ def decomposition(
153168
}
154169
try:
155170
statistic_method = statistic_methods[statistic]
156-
except IndexError:
157-
msg = f"'statistic' must be one of {statistic_methods.values()}"
171+
except KeyError:
172+
msg = f"'statistic' must be one of {statistic_methods.keys()}"
158173
raise ValueError(msg)
159174

160175
def statistic_(inputs):

0 commit comments

Comments
 (0)