1010# ' @param splits mlr3 object defining data splits for train and test sets
1111# ' @param subset numeric, what percentage of the instances to use from 0 to 1 where 1 means all
1212# '
13+ # ' @importFrom dplyr filter mutate
14+ # ' @importFrom forcats fct_reorder
1315# ' @importFrom magrittr %>%
1416# ' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim
1517# ' @export
@@ -194,28 +196,48 @@ eSHAP_plot <- function(task,
194196 shap_Mean $ correct_prediction <- factor (shap_Mean $ correct_prediction , levels = c(FALSE , TRUE ), labels = c(" Incorrect" , " Correct" ))
195197
196198
197- shap_plot <- shap_Mean %> %
198- mutate(feature = forcats :: fct_reorder(feature , mean_phi )) %> %
199+ # Prepare data for plotting with robust error handling
200+ # Handle NaN values in mean_phi by recalculating from Phi
201+ shap_Mean <- shap_Mean %> %
202+ group_by(feature ) %> %
203+ mutate(mean_phi = ifelse(is.nan(mean_phi ) | is.na(mean_phi ),
204+ mean(Phi , na.rm = TRUE ),
205+ mean_phi )) %> %
206+ ungroup()
207+
208+ plot_data <- shap_Mean %> %
209+ filter(! is.na(feature )) %> %
210+ filter(! is.na(Phi ), is.finite(Phi )) %> %
211+ filter(! is.na(f_val ), is.finite(f_val )) %> %
212+ mutate(
213+ feature = as.character(feature ),
214+ feature = factor (feature ),
215+ Phi = as.numeric(Phi ),
216+ f_val = as.numeric(f_val ),
217+ mean_phi = as.numeric(mean_phi ),
218+ sample_num = as.integer(sample_num )
219+ )
220+
221+ # Check if we have data to plot
222+ if (nrow(plot_data ) == 0 ) {
223+ stop(" No valid data for plotting after filtering" )
224+ }
225+
226+ shap_plot <- plot_data %> %
199227 ggplot(aes(x = feature , y = Phi , color = f_val )) +
200228 geom_violin(colour = " grey" ) +
201- geom_line(aes(group = sample_num ), alpha = 0.1 , size = 0.2 ) +
229+ geom_line(aes(group = sample_num ), alpha = 0.1 , linewidth = 0.2 ) +
202230 coord_flip() +
203- geom_jitter(aes(shape = correct_prediction , text = paste(
204- " Feature: " , feature ,
205- " <br>Unscaled feature value: " , unscaled_f_val ,
206- " <br>SHAP value: " , Phi ,
207- " <br>Prediction correctness: " , correct_prediction ,
208- " <br>Predicted probability: " , pred_prob ,
209- " <br>Predicted class: " , pred_class
210- )),
231+ geom_jitter(aes(shape = correct_prediction ),
211232 alpha = 0.6 , size = 1.5 , position = position_jitter(width = 0.2 , height = 0 )
212233 ) +
213- scale_shape_manual(values = c(4 , 19 ), guide = FALSE ) +
234+ scale_shape_manual(values = c(4 , 19 ), guide = " none " ) +
214235 # scale_color_manual(values=c("black","grey")) +
215236 labs(shape = " model prediction" ) +
216237 scale_colour_gradient2(low = " blue" , mid = " green" , high = " red" , midpoint = 0.5 , breaks = c(0 , 1 ), labels = c(" Low" , " High" )) +
217238 guides(shape = ggplot2 :: guide_legend(override.aes = list (fill = " black" , color = " black" ))) +
218- geom_text(aes(x = feature , y = - Inf , label = sprintf(" %.3f" , mean_phi )), hjust = - 0.2 , alpha = 0.7 , color = " black" ) +
239+ # Remove problematic geom_text that might cause coord_flip issues
240+ # geom_text(aes(x = feature, y = -Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
219241 theme(
220242 axis.line.y = element_blank(), axis.ticks.y = element_blank(),
221243 legend.position = " right"
@@ -235,10 +257,10 @@ eSHAP_plot <- function(task,
235257 axis.line = element_line(colour = " grey" ),
236258 legend.key.width = grid :: unit(2 , " mm" )
237259 ) +
238- ylim(min(shap_Mean $ Phi ) - 0.05 , max(shap_Mean $ Phi ) + 0.05 )
260+ ylim(min(plot_data $ Phi , na.rm = TRUE ) - 0.05 , max(plot_data $ Phi , na.rm = TRUE ) + 0.05 )
239261
240- # Convert ggplot to Plotly
241- shap_plot <- ggplotly(shap_plot , tooltip = " text " )
262+ # Convert ggplot to Plotly (simplified without text tooltips)
263+ shap_plot <- ggplotly(shap_plot )
242264
243265 # Additional plot to show SHAP values vs. predicted probabilities
244266 shap_pred_plot <- shap_Mean %> %
0 commit comments