Skip to content

Commit f452ce3

Browse files
authored
Fix eSHAP_plot issue
Fix eSHAP_plot NaN issue and SHAPclust dependency handling - Fix NaN values in mean_phi aggregation in eSHAP_plot - Add fallback for missing psych package in SHAPclust - Fixes #1"
1 parent 1c3f2af commit f452ce3

2 files changed

Lines changed: 62 additions & 19 deletions

File tree

R/SHAPclust.R

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
#' @param iter.max maximum number of iterations allowed
1414
#'
1515
#' @importFrom magrittr %>%
16-
#' @importFrom dplyr mutate
16+
#' @importFrom dplyr mutate filter
17+
#' @importFrom forcats fct_reorder
1718
#' @importFrom ggplot2 ggtitle ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme geom_hline element_blank element_text element_line ylim facet_wrap ggsave
1819
#' @importFrom plotly ggplotly
1920
#' @importFrom tibble tibble as_tibble
@@ -159,7 +160,24 @@ SHAPclust <- function(task,
159160
colnames(kmeans_fvals)[1] <- "cluster"
160161

161162
# save the statistical descriptions of the clusters by feature values
162-
kmeans_fvals_desc <- psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
163+
# Use psych package if available, otherwise create basic summary
164+
kmeans_fvals_desc <- tryCatch({
165+
if (requireNamespace("psych", quietly = TRUE)) {
166+
psych::describeBy(kmeans_fvals, group = kmeans_fvals$cluster)
167+
} else {
168+
# Fallback: create basic summary using base R
169+
warning("Package 'psych' not available. Using basic summary instead of detailed description.")
170+
aggregate(. ~ cluster, data = kmeans_fvals, FUN = function(x) {
171+
c(mean = mean(x, na.rm = TRUE),
172+
sd = sd(x, na.rm = TRUE),
173+
min = min(x, na.rm = TRUE),
174+
max = max(x, na.rm = TRUE))
175+
})
176+
}
177+
}, error = function(e) {
178+
warning("Could not create cluster descriptions: ", conditionMessage(e))
179+
NULL
180+
})
163181
shap_Mean_wide_kmeans$row_ids <- shap_Mean_wide_kmeans$row_ids - shap_Mean_wide_kmeans$row_ids[1] + 1
164182
shap_Mean_wide_kmeans[, prediction_correctness := (truth == response)]
165183
shap_Mean_wide_kmeans_forCM <- shap_Mean_wide_kmeans
@@ -189,7 +207,10 @@ SHAPclust <- function(task,
189207
print(dt_long)
190208
############## SHAP plots for clusters
191209
shap_plot1 <- dt_long %>%
192-
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
210+
# Clean data to ensure forcats::fct_reorder works properly
211+
filter(!is.na(feature), !is.na(mean_phi), is.finite(mean_phi)) %>%
212+
mutate(feature = as.character(feature)) %>%
213+
mutate(feature = forcats::fct_reorder(feature, mean_phi, .fun = function(x) mean(x, na.rm = TRUE))) %>%
193214
ggplot(aes(x = feature, y = Phi, color = f_val)) +
194215
geom_violin(colour = "grey") +
195216
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +

R/eSHAP_plot.R

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#' @param splits mlr3 object defining data splits for train and test sets
1111
#' @param subset numeric, what percentage of the instances to use from 0 to 1 where 1 means all
1212
#'
13+
#' @importFrom dplyr filter mutate
14+
#' @importFrom forcats fct_reorder
1315
#' @importFrom magrittr %>%
1416
#' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim
1517
#' @export
@@ -194,28 +196,48 @@ eSHAP_plot <- function(task,
194196
shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect", "Correct"))
195197

196198

197-
shap_plot <- shap_Mean %>%
198-
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
199+
# Prepare data for plotting with robust error handling
200+
# Handle NaN values in mean_phi by recalculating from Phi
201+
shap_Mean <- shap_Mean %>%
202+
group_by(feature) %>%
203+
mutate(mean_phi = ifelse(is.nan(mean_phi) | is.na(mean_phi),
204+
mean(Phi, na.rm = TRUE),
205+
mean_phi)) %>%
206+
ungroup()
207+
208+
plot_data <- shap_Mean %>%
209+
filter(!is.na(feature)) %>%
210+
filter(!is.na(Phi), is.finite(Phi)) %>%
211+
filter(!is.na(f_val), is.finite(f_val)) %>%
212+
mutate(
213+
feature = as.character(feature),
214+
feature = factor(feature),
215+
Phi = as.numeric(Phi),
216+
f_val = as.numeric(f_val),
217+
mean_phi = as.numeric(mean_phi),
218+
sample_num = as.integer(sample_num)
219+
)
220+
221+
# Check if we have data to plot
222+
if (nrow(plot_data) == 0) {
223+
stop("No valid data for plotting after filtering")
224+
}
225+
226+
shap_plot <- plot_data %>%
199227
ggplot(aes(x = feature, y = Phi, color = f_val)) +
200228
geom_violin(colour = "grey") +
201-
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
229+
geom_line(aes(group = sample_num), alpha = 0.1, linewidth = 0.2) +
202230
coord_flip() +
203-
geom_jitter(aes(shape = correct_prediction, text = paste(
204-
"Feature: ", feature,
205-
"<br>Unscaled feature value: ", unscaled_f_val,
206-
"<br>SHAP value: ", Phi,
207-
"<br>Prediction correctness: ", correct_prediction,
208-
"<br>Predicted probability: ", pred_prob,
209-
"<br>Predicted class: ", pred_class
210-
)),
231+
geom_jitter(aes(shape = correct_prediction),
211232
alpha = 0.6, size = 1.5, position = position_jitter(width = 0.2, height = 0)
212233
) +
213-
scale_shape_manual(values = c(4, 19), guide = FALSE) +
234+
scale_shape_manual(values = c(4, 19), guide = "none") +
214235
# scale_color_manual(values=c("black","grey")) +
215236
labs(shape = "model prediction") +
216237
scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) +
217238
guides(shape = ggplot2::guide_legend(override.aes = list(fill = "black", color = "black"))) +
218-
geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
239+
# Remove problematic geom_text that might cause coord_flip issues
240+
# geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
219241
theme(
220242
axis.line.y = element_blank(), axis.ticks.y = element_blank(),
221243
legend.position = "right"
@@ -235,10 +257,10 @@ eSHAP_plot <- function(task,
235257
axis.line = element_line(colour = "grey"),
236258
legend.key.width = grid::unit(2, "mm")
237259
) +
238-
ylim(min(shap_Mean$Phi) - 0.05, max(shap_Mean$Phi) + 0.05)
260+
ylim(min(plot_data$Phi, na.rm = TRUE) - 0.05, max(plot_data$Phi, na.rm = TRUE) + 0.05)
239261

240-
# Convert ggplot to Plotly
241-
shap_plot <- ggplotly(shap_plot, tooltip = "text")
262+
# Convert ggplot to Plotly (simplified without text tooltips)
263+
shap_plot <- ggplotly(shap_plot)
242264

243265
# Additional plot to show SHAP values vs. predicted probabilities
244266
shap_pred_plot <- shap_Mean %>%

0 commit comments

Comments
 (0)