|
| 1 | +from typing import List, Tuple, Union |
| 2 | +from numpy import linalg as LA |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
| 6 | +import torch |
| 7 | +import logging |
| 8 | +import yaml |
| 9 | +from data.data_object import DataObject |
| 10 | +from method.method_factory import register_method |
| 11 | +from evaluation.utils import check_counterfactuals |
| 12 | +from experiment_utils import deep_merge, reconstruct_encoding_constraints |
| 13 | +from method.method_object import MethodObject |
| 14 | +from model.catalog.autoencoder.vae import VariationalAutoencoder |
| 15 | +from model.model_object import ModelObject |
| 16 | + |
@register_method("CCHVAE")
class CCHVAE(MethodObject):
    """
    Implementation of CCHVAE [1]_

    Counterfactuals are found by sampling, in hyper-sphere shells of growing
    radius around the factual's latent representation, decoded points whose
    predicted class differs from the factual's, and returning the closest one.

    .. [1] Pawelczyk, Martin, Klaus Broelemann and Gjergji Kasneci. “Learning Model-Agnostic Counterfactual Explanations
        for Tabular Data.” Proceedings of The Web Conference 2020 (2020): n. pag..
    """

    def __init__(self, data: DataObject, model: ModelObject, config_override=None):
        """
        :param data: dataset wrapper providing feature names and categorical metadata
        :param model: model wrapper providing ``predict_proba``, train data and mutable mask
        :param config_override: optional dict deep-merged over the defaults from config.yml
        """
        super().__init__(data, model, config_override)

        # Load default hyperparameters; `with` closes the handle deterministically.
        with open("method/catalog/CCHVAE/library/config.yml", "r") as config_file:
            self.config = yaml.safe_load(config_file)

        # Merge user-specified overrides on top of the defaults, if any.
        if self._config_override is not None:
            self.config = deep_merge(self.config, self._config_override)

        self._feature_order = self._data.get_feature_names(expanded=True)

        self._n_search_samples = self.config["n_search_samples"]
        self._p_norm = self.config["p_norm"]
        self._step = self.config["step"]
        self._max_iter = self.config["max_iter"]
        self._clamp = self.config["clamp"]

        vae_params = self.config["vae_params"]
        self._vae = VariationalAutoencoder(
            data_name=self.config["data_name"] or "Temp",
            # Input layer width equals the number of mutable features.
            layers=[int(sum(self._model.get_mutable_mask()))] + vae_params["layers"],
            mutable_mask=self._model.get_mutable_mask(),
        )

        if vae_params["train"]:
            self._vae.fit(
                xtrain=self._model.get_train_data()[0][self._feature_order],
                kl_weight=vae_params["kl_weight"],
                lambda_reg=vae_params["lambda_reg"],
                epochs=vae_params["epochs"],
                lr=vae_params["lr"],
                batch_size=vae_params["batch_size"],
            )
        else:
            try:
                self._vae.load(vae_params["layers"][0])
            except FileNotFoundError as exc:
                # Chain the original exception so the root cause stays visible.
                raise FileNotFoundError(
                    "Loading of Autoencoder failed. {}".format(str(exc))
                ) from exc

    # NOTE(review): method name keeps the original "coordindates" typo so any
    # external caller or test referencing it keeps working.
    def _hyper_sphere_coordindates(
        self, instance: np.ndarray, high: float, low: float
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample ``self._n_search_samples`` candidate points on a hyper-sphere
        shell of radius in ``[low, high)`` around ``instance``.

        :param instance: 2-D array of shape (1, n_latent) — the point to search around
        :param high: upper bound of the shell radius, ``high > low >= 0``
        :param low: lower bound of the shell radius
        :return: tuple of (candidate points, sampled radii)
        """
        delta_instance = np.random.randn(self._n_search_samples, instance.shape[1])
        # Radii drawn uniformly from [low, high).
        dist = np.random.rand(self._n_search_samples) * (high - low) + low
        # Rescale each random direction so its p-norm equals its sampled radius.
        norm_p = LA.norm(delta_instance, ord=self._p_norm, axis=1)
        d_norm = np.divide(dist, norm_p).reshape(-1, 1)
        delta_instance = np.multiply(delta_instance, d_norm)
        candidate_counterfactuals = instance + delta_instance
        return candidate_counterfactuals, dist

    def _counterfactual_search(
        self, step: float, factual: np.ndarray, cat_features_indices: List[List[int]]
    ) -> np.ndarray:
        """
        Grow a latent-space search sphere until a decoded sample flips the
        predicted label; return the closest such sample.

        :param step: increment by which the sphere radius grows per round
        :param factual: 2-D array of shape (1, n_features), the encoded factual
        :param cat_features_indices: per categorical feature, the column indices
            of its one-hot encoding (used to re-binarize decoded samples)
        :return: flat feature vector of the closest counterfactual found; if the
            iteration budget runs out, the first sample of the last decoded
            batch is returned instead — it may not be a counterfactual and is
            filtered downstream by ``check_counterfactuals``
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Shell bounds for the growing sphere, and the iteration budget counter.
        low = 0
        high = step
        count = 0
        counter_step = 1

        torch_fact = torch.from_numpy(factual).to(device)

        # Class currently predicted for the factual; we search for any other.
        instance_label = np.argmax(
            self._model.predict_proba(torch_fact.float()).cpu().detach().numpy(),
            axis=1,
        )

        # Encode only the mutable features, then append the immutables to the latent.
        z = self._vae.encode(torch_fact[:, self._vae.mutable_mask].float())[0]
        z = torch.cat([z, torch_fact[:, ~self._vae.mutable_mask]], dim=-1)
        z = z.cpu().detach().numpy()
        z_rep = np.repeat(z.reshape(1, -1), self._n_search_samples, axis=0)

        # Replicated factual rows; immutable columns are copied back from here.
        fact_rep = torch_fact.reshape(1, -1).repeat_interleave(
            self._n_search_samples, dim=0
        )

        # Hoisted: the factual as numpy, used for distances every iteration.
        fact_np = torch_fact.cpu().detach().numpy()

        candidate_dist: List = []
        x_ce: Union[np.ndarray, torch.Tensor] = np.array([])
        while count <= self._max_iter or len(candidate_dist) <= 0:
            count = count + counter_step
            if count > self._max_iter:
                logging.debug("No counterfactual example found")
                # NOTE(review): if _max_iter < 1 this indexes an empty array;
                # with the budget exhausted it returns a non-counterfactual
                # that check_counterfactuals filters out later.
                return x_ce[0]

            # STEP 1 -- sample latent points on the current sphere shell.
            latent_neighbourhood, _ = self._hyper_sphere_coordindates(z_rep, high, low)
            torch_latent_neighbourhood = (
                torch.from_numpy(latent_neighbourhood).to(device).float()
            )
            x_ce = self._vae.decode(torch_latent_neighbourhood)

            # Restore the immutable features around the decoded mutables.
            temp = fact_rep.clone()
            temp[:, self._vae.mutable_mask] = x_ce.float()
            x_ce = temp

            # Snap one-hot groups back to valid binary encodings.
            x_ce = reconstruct_encoding_constraints(x_ce, cat_features_indices)
            x_ce = x_ce.detach().cpu().numpy()
            x_ce = x_ce.clip(0, 1) if self._clamp else x_ce

            # STEP 2 -- distance of every candidate to the factual (l1 or l2).
            if self._p_norm == 1:
                distances = np.abs(x_ce - fact_np).sum(axis=1)
            elif self._p_norm == 2:
                distances = LA.norm(x_ce - fact_np, axis=1)
            else:
                raise ValueError("Possible values for p_norm are 1 or 2")

            # STEP 3 -- keep only candidates whose predicted label flipped.
            y_candidate = np.argmax(
                self._model.predict_proba(torch.from_numpy(x_ce).float())
                .cpu()
                .detach()
                .numpy(),
                axis=1,
            )
            indices = np.where(y_candidate != instance_label)
            candidate_counterfactuals = x_ce[indices]
            candidate_dist = distances[indices]

            if len(candidate_dist) == 0:
                # No candidate found: push the search shell further out.
                low = high
                high = low + step
            else:
                # Candidates found: return the one closest to the factual.
                min_index = np.argmin(candidate_dist)
                logging.debug("Counterfactual example found")
                return candidate_counterfactuals[min_index]

    def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
        """
        Compute one counterfactual per row of ``factuals``.

        :param factuals: encoded factual instances, one per row
        :return: dataframe of counterfactuals, validated/filtered by
            ``check_counterfactuals`` against ``factuals.index``
        """
        # Ensure the feature ordering is correct for the model input.
        factuals = factuals[self._feature_order]

        # Column index groups of the one-hot encoded categorical features.
        encoded_feature_names = self._data.get_categorical_features(expanded=True)
        cat_features_indices = [
            [factuals.columns.get_loc(feat) for feat in features]
            for features in encoded_feature_names
        ]

        # raw=True hands each row to the lambda as a bare ndarray.
        df_cfs = factuals.apply(
            lambda x: self._counterfactual_search(
                self._step, x.reshape((1, -1)), cat_features_indices
            ),
            raw=True,
            axis=1,
        )

        df_cfs = check_counterfactuals(self._model, self._data, df_cfs, factuals.index)
        return df_cfs
| 201 | + |