
Make explain_tidymodels avoid covariates with certain roles (defined by recipes) #92

@ajunquera

Hello!

Thanks for the package, it's very useful. I'm using explain_tidymodels with model_parts for a classification task, and I would like model_parts to ignore certain variables that I have assigned a specific role with recipes. In my case it is especially important that one of these variables is not treated as a predictor, because it is a numerical version of the target variable. My concern is that the importance scores have been calculated with this numerical target used as a predictor of the binary target, which would obviously make the rest of the predictors look unimportant.

I have tried the two solutions proposed in #65, but neither works:

# Problem with ROLES in DALEXtra

## Option A: recipes --------

library(DALEXtra)
#> Warning: package 'DALEXtra' was built under R version 4.4.3
#> Loading required package: DALEX
#> Warning: package 'DALEX' was built under R version 4.4.3
#> Welcome to DALEX (version: 2.4.3).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/

library(tidymodels)

dataset <- titanic_imputed
dataset$survived <- as.factor(dataset$survived)

rec <- recipe(survived ~ ., data = dataset) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare) %>%
  step_dummy(all_nominal_predictors())

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = dataset)

explainer <- explain_tidymodels(
  model_fitted, 
  data = rec %>% prep() %>% bake(new_data = NULL, all_predictors(), all_outcomes()), 
  y = titanic_imputed$survived
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  2207  rows  14  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 1.2.0 , task classification (  default  ) 
#>   -> predicted values  :  the predict_function returns an error when executed (  WARNING  ) 
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
#>   A new explainer has been created!

model_parts(explainer) 
#> Error in `validate_column_names()`:
#> ! The following required columns are missing: 'gender', 'class', 'embarked'.

My suspicion here is that bake() returns the processed design matrix (with the one-hot encoding created by step_dummy()), whereas the fitted workflow expects the original variables (gender, class, embarked) so that it can apply the recipe itself.
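One way to check this suspicion is to compare the columns returned by bake() with the raw columns of dataset. The sketch below rebuilds the recipe from the reprex; the names raw_cols and baked_cols are only illustrative. If the suspicion is right, the original factor columns (replaced by dummy columns) and parch (which has a non-predictor role) should be missing from the baked data.

```r
library(DALEX)       # provides titanic_imputed
library(tidymodels)

dataset <- titanic_imputed
dataset$survived <- as.factor(dataset$survived)

# Same recipe as in the reprex
rec <- recipe(survived ~ ., data = dataset) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare) %>%
  step_dummy(all_nominal_predictors())

# Columns of the baked data: dummy columns replace the original factors,
# and parch is dropped because it is neither a predictor nor an outcome
baked_cols <- names(bake(prep(rec), new_data = NULL, all_predictors(), all_outcomes()))

# Columns of the raw data, which is what the fitted workflow's predict() works on
raw_cols <- names(dataset)

# Raw columns that no longer exist after baking
setdiff(raw_cols, baked_cols)
```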

Another alternative is to remove parch manually, but that does not work either.

## Option B: remove parch manually --------
explainer2 <- explain_tidymodels(
  model_fitted, 
  data = dataset %>% select(-parch), 
  y = titanic_imputed$survived
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  2207  rows  7  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 1.2.0 , task classification (  default  ) 
#>   -> predicted values  :  the predict_function returns an error when executed (  WARNING  ) 
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
#>   A new explainer has been created!

model_parts(explainer2) 
#> Error in `validate_column_names()`:
#> ! The following required columns are missing: 'parch'.

Created on 2025-07-17 with reprex v2.1.1
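A minimal sketch of a possible workaround, assuming that model_parts() forwards a variables argument to the permutation step (DALEX documents such an argument for restricting which variables are permuted): give the explainer the raw, un-baked columns, keeping parch so that the recipe inside the workflow can still find it, and then permute only the variables of interest. The names vars_to_assess and vi are hypothetical.

```r
library(DALEXtra)    # also attaches DALEX, which provides titanic_imputed
library(tidymodels)

dataset <- titanic_imputed
dataset$survived <- as.factor(dataset$survived)

rec <- recipe(survived ~ ., data = dataset) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare) %>%
  step_dummy(all_nominal_predictors())

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

model_fitted <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model) %>%
  fit(data = dataset)

# Raw (un-baked) columns, so the workflow applies its own recipe;
# parch stays in the data only because the recipe still requires it
explainer <- explain_tidymodels(
  model_fitted,
  data = dataset %>% select(-survived),
  y = titanic_imputed$survived
)

# Permute every explainer column except parch
vars_to_assess <- setdiff(colnames(explainer$data), "parch")
vi <- model_parts(explainer, variables = vars_to_assess)
vi
```

Since parch has a non-predictor role, the fitted tree should not use it at all, so leaving it out of variables mainly keeps it out of the report; it should not change the scores of the other variables beyond sampling noise.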

My actual use case is slightly different, because I would like to estimate the importance scores on the test set. I have tried both bake(new_data = testdata, all_predictors(), all_outcomes()) and testdata %>% select(-parch), but neither works.
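For the test-set case, the same idea might be applied by fitting on a training split and passing the raw test data (parch included) to the explainer, again restricting the permuted variables. A sketch, assuming an rsample split made with initial_split(); traindata, testdata and vi_test are illustrative names, and the 0.75 proportion and the seed are arbitrary choices.

```r
library(DALEXtra)    # also attaches DALEX, which provides titanic_imputed
library(tidymodels)

dataset <- titanic_imputed
dataset$survived <- as.factor(dataset$survived)

# Hypothetical train/test split (proportion and seed are arbitrary)
set.seed(123)
split <- initial_split(dataset, prop = 0.75, strata = survived)
traindata <- training(split)
testdata  <- testing(split)

rec <- recipe(survived ~ ., data = traindata) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare) %>%
  step_dummy(all_nominal_predictors())

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

model_fitted <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model) %>%
  fit(data = traindata)

# Raw test columns (parch kept so the recipe can be applied), numeric 0/1 target
explainer_test <- explain_tidymodels(
  model_fitted,
  data = testdata %>% select(-survived),
  y = as.numeric(as.character(testdata$survived)),
  label = "decision tree (test set)"
)

# Permutation importance on the test set, excluding parch
vi_test <- model_parts(
  explainer_test,
  variables = setdiff(colnames(explainer_test$data), "parch")
)
vi_test
```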

Thanks in advance!
