Skip to content

Commit acc0f66

Browse files
committed
Fix sensitivity indices to match Matlab SOE calculation.
1 parent bf93d74 commit acc0f66

1 file changed

Lines changed: 51 additions & 37 deletions

File tree

src/simdec/sensitivity_indices.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from dataclasses import dataclass
2-
import warnings
32

43
import numpy as np
54
import pandas as pd
@@ -61,7 +60,7 @@ def sensitivity_indices(
6160
Sensitivity indices, combined effect of each input.
6261
foe : ndarray of shape (n_factors, 1)
6362
First-order effects (also called 'main' or 'individual').
64-
soe : ndarray of shape (n_factors, 1)
63+
soe_full : ndarray of shape (n_factors, 1)
6564
Second-order effects (also called 'interaction').
6665
6766
Examples
@@ -96,15 +95,21 @@ def sensitivity_indices(
9695
array([0.43157591, 0.44241433, 0.11767249])
9796
9897
"""
98+
# Handle inputs conversion
9999
if isinstance(inputs, pd.DataFrame):
100100
cat_columns = inputs.select_dtypes(["category", "O"]).columns
101101
inputs[cat_columns] = inputs[cat_columns].apply(
102102
lambda x: x.astype("category").cat.codes
103103
)
104104
inputs = inputs.to_numpy()
105-
if isinstance(output, pd.DataFrame):
105+
106+
# Handle output conversion first, then flatten
107+
if isinstance(output, (pd.DataFrame, pd.Series)):
106108
output = output.to_numpy()
107109

110+
# Flatten output if it's (N, 1)
111+
output = output.flatten()
112+
108113
n_runs, n_factors = inputs.shape
109114
n_bins_foe, n_bins_soe = number_of_bins(n_runs, n_factors)
110115

@@ -116,55 +121,64 @@ def sensitivity_indices(
116121
soe = np.zeros((n_factors, n_factors))
117122

118123
for i in range(n_factors):
119-
# first order
124+
# 1. First-order effects (FOE)
120125
xi = inputs[:, i]
121126

122127
bin_avg, _, binnumber = stats.binned_statistic(
123-
x=xi, values=output, bins=n_bins_foe
128+
x=xi, values=output, bins=n_bins_foe, statistic="mean"
124129
)
125-
# can have NaN in the average but no corresponding binnumber
126-
bin_avg = bin_avg[~np.isnan(bin_avg)]
127-
bin_counts = np.unique(binnumber, return_counts=True)[1]
128130

129-
# weighted variance and divide by the overall variance of the output
130-
foe[i] = _weighted_var(bin_avg, weights=bin_counts) / var_y
131+
# Filter empty bins and get weights (counts)
132+
mask_foe = ~np.isnan(bin_avg)
133+
mean_i_foe = bin_avg[mask_foe]
134+
# binnumber starts at 1; 0 is for values outside range
135+
bin_counts_foe = np.unique(binnumber[binnumber > 0], return_counts=True)[1]
131136

132-
# second order
137+
foe[i] = _weighted_var(mean_i_foe, weights=bin_counts_foe) / var_y
138+
139+
# 2. Second-order effects (SOE)
133140
for j in range(n_factors):
134-
if i == j or j < i:
141+
if j <= i:
135142
continue
136143

137144
xj = inputs[:, j]
138145

139-
bin_avg, *edges, binnumber = stats.binned_statistic_2d(
146+
# 2D Binned Statistic for Var(E[Y|Xi, Xj])
147+
bin_avg_ij, x_edges, y_edges, binnumber_ij = stats.binned_statistic_2d(
140148
x=xi, y=xj, values=output, bins=n_bins_soe, expand_binnumbers=False
141149
)
142150

143-
mean_ij = bin_avg[~np.isnan(bin_avg)]
144-
bin_counts = np.unique(binnumber, return_counts=True)[1]
145-
var_ij = _weighted_var(mean_ij, weights=bin_counts)
146-
147-
# expand_binnumbers here
148-
nbin = np.array([len(edges_) + 1 for edges_ in edges])
149-
binnumbers = np.asarray(np.unravel_index(binnumber, nbin))
151+
mask_ij = ~np.isnan(bin_avg_ij)
152+
mean_ij = bin_avg_ij[mask_ij]
153+
counts_ij = np.unique(binnumber_ij[binnumber_ij > 0], return_counts=True)[1]
154+
var_ij = _weighted_var(mean_ij, weights=counts_ij)
150155

151-
bin_counts_i = np.unique(binnumbers[0], return_counts=True)[1]
152-
bin_counts_j = np.unique(binnumbers[1], return_counts=True)[1]
153-
154-
# handle NaNs
155-
with warnings.catch_warnings():
156-
warnings.simplefilter("ignore", RuntimeWarning)
157-
mean_i = np.nanmean(bin_avg, axis=1)
158-
mean_i = mean_i[~np.isnan(mean_i)]
159-
mean_j = np.nanmean(bin_avg, axis=0)
160-
mean_j = mean_j[~np.isnan(mean_j)]
161-
162-
var_i = _weighted_var(mean_i, weights=bin_counts_i)
163-
var_j = _weighted_var(mean_j, weights=bin_counts_j)
156+
# Marginal Var(E[Y|Xi]) using n_bins_soe to match MATLAB logic
157+
bin_avg_i_soe, _, binnumber_i_soe = stats.binned_statistic(
158+
x=xi, values=output, bins=n_bins_soe, statistic="mean"
159+
)
160+
mask_i = ~np.isnan(bin_avg_i_soe)
161+
counts_i = np.unique(
162+
binnumber_i_soe[binnumber_i_soe > 0], return_counts=True
163+
)[1]
164+
var_i_soe = _weighted_var(bin_avg_i_soe[mask_i], weights=counts_i)
165+
166+
# Marginal Var(E[Y|Xj]) using n_bins_soe to match MATLAB logic
167+
bin_avg_j_soe, _, binnumber_j_soe = stats.binned_statistic(
168+
x=xj, values=output, bins=n_bins_soe, statistic="mean"
169+
)
170+
mask_j = ~np.isnan(bin_avg_j_soe)
171+
counts_j = np.unique(
172+
binnumber_j_soe[binnumber_j_soe > 0], return_counts=True
173+
)[1]
174+
var_j_soe = _weighted_var(bin_avg_j_soe[mask_j], weights=counts_j)
164175

165-
soe[i, j] = (var_ij - var_i - var_j) / var_y
176+
soe[i, j] = (var_ij - var_i_soe - var_j_soe) / var_y
166177

167-
soe = np.where(soe == 0, soe.T, soe)
168-
si[i] = foe[i] + soe[:, i].sum() / 2
178+
# Mirror SOE and calculate Combined Effect (SI)
179+
# SI is FOE + half of all interactions associated with that variable
180+
soe_full = soe + soe.T
181+
for k in range(n_factors):
182+
si[k] = foe[k] + (soe_full[:, k].sum() / 2)
169183

170-
return SensitivityAnalysisResult(si, foe, soe)
184+
return SensitivityAnalysisResult(si, foe, soe_full)

0 commit comments

Comments
 (0)