Skip to content

Commit 7274585

Browse files
committed
Added CCHVAE
The method is working but still needs a reproducibility check.
1 parent bf91577 commit 7274585

5 files changed

Lines changed: 269 additions & 2 deletions

File tree

experiment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import method.catalog.ClaPROAR.method # noqa: F401
2424
import method.catalog.REVISE.method # noqa: F401
2525
import method.catalog.GRAVITATIONAL.method # noqa: F401
26+
import method.catalog.CCHVAE.method # noqa: F401
2627
import evaluation.catalog.distances # noqa: F401
2728
import evaluation.catalog.validity # noqa: F401
2829

@@ -56,6 +57,7 @@
5657
"ClaPROAR": "method/catalog/ClaPROAR/library/config.yml",
5758
"REVISE": "method/catalog/REVISE/library/config.yml",
5859
"GRAVITATIONAL" : "method/catalog/GRAVITATIONAL/library/config.yml",
60+
"CCHVAE" : "method/catalog/CCHVAE/library/config.yml",
5961
# add more method types and their config paths here
6062
}
6163

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# ============================================================
# Top-Level Experiment Configuration
# ============================================================
# This file is the ONLY thing a user needs to create/modify.
# All other layer configs can be overridden from here.

experiment:
  name: "german_cchvae_mlp_experiment"
  seed: 42                              # random seed for the run
  num_factuals: 5                       # how many negative-class samples to generate counterfactuals for
  factual_selection: "negative_class"   # Options: "negative_class", "all"
  output_dir: "./results"
  save_results: true
  output_format: "csv"                  # Options: "csv", "json", "both"
  logger: "info"                        # Options: "debug", "info", "warning", "error"

# ---------- Data Layer ----------
data:
  - name: "german"
    # Override any key inside the data config without editing it directly.
    # Keys here are merged ON TOP of whatever is in data_config_<name>.yml.
    overrides:
      train_split: 0.8
      preprocessing_strategy: "normalize"  # Options: "normalize", "standardize", "none"
      features:
        age:
          type: "numerical"
          node_type: "input"
          actionability: "same-or-increase"
          # NOTE(review): capitalized "False" vs lowercase "true" above — both
          # parse as booleans, but one style should be picked for consistency.
          mutability: False

# ---------- Model Layer ----------
model:
  name: "mlp"        # Options: "mlp", "logistic_regression", etc.
  overrides:         # empty -> the model config's defaults are used unchanged

# ---------- Method Layer ----------
method:
  name: "CCHVAE"     # Options: "ROAR", "PROBE", etc.
  overrides:         # empty -> method/catalog/CCHVAE/library/config.yml is used as-is

# ---------- Evaluation Layer ----------
evaluation:
  metrics:
    - name: "Distance"
      # hyperparameters: {} # Optional per-metric hyperparameters
    # Future: you could add more metric objects here
    # - name: "Sparsity"
    # - name: "Validity"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Default hyperparameters for the CCHVAE method.
# Each key is read in method/catalog/CCHVAE/method.py (__init__).

data_name: null            # name handed to the VAE; the method falls back to "Temp" when null
n_search_samples: 300      # candidate points sampled per hyper-sphere round
p_norm: 1                  # distance norm for ranking candidates; the search raises unless 1 or 2
step: 0.1                  # radius increment when growing the search sphere
max_iter: 1000             # maximum search rounds before giving up
clamp: True                # clip decoded candidates into [0, 1]
# NOTE(review): binary_cat_features is not read anywhere in method.py as
# committed — confirm whether it is consumed elsewhere or can be dropped.
binary_cat_features: True
vae_params:
  layers: [512, 256, 8]    # hidden widths; the input width (count of mutable features) is prepended
  train: True              # train a fresh VAE; when False, a saved one is loaded instead
  kl_weight: 0.3
  lambda_reg: 0.000001
  epochs: 5
  lr: 0.001
  batch_size: 32
16+

method/catalog/CCHVAE/method.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from typing import List, Tuple, Union
2+
from numpy import linalg as LA
3+
4+
import numpy as np
5+
import pandas as pd
6+
import torch
7+
import logging
8+
import yaml
9+
from data.data_object import DataObject
10+
from method.method_factory import register_method
11+
from evaluation.utils import check_counterfactuals
12+
from experiment_utils import deep_merge, reconstruct_encoding_constraints
13+
from method.method_object import MethodObject
14+
from model.catalog.autoencoder.vae import VariationalAutoencoder
15+
from model.model_object import ModelObject
16+
17+
@register_method("CCHVAE")
18+
class CCHVAE(MethodObject):
19+
"""
20+
Implementation of CCHVAE [1]_
21+
22+
.. [1] Pawelczyk, Martin, Klaus Broelemann and Gjergji Kasneci. “Learning Model-Agnostic Counterfactual Explanations
23+
for Tabular Data.” Proceedings of The Web Conference 2020 (2020): n. pag..
24+
"""
25+
def __init__(self, data: DataObject, model: ModelObject, config_override = None):
26+
super().__init__(data, model, config_override)
27+
28+
# get configs from config file
29+
self.config = yaml.safe_load(open("method/catalog/CCHVAE/library/config.yml", 'r'))
30+
31+
# merge configs with user specified, if they exist
32+
if self._config_override is not None:
33+
self.config = deep_merge(self.config, self._config_override)
34+
35+
self._feature_order = self._data.get_feature_names(expanded=True)
36+
37+
self._n_search_samples = self.config["n_search_samples"]
38+
self._p_norm = self.config["p_norm"]
39+
self._step = self.config["step"]
40+
self._max_iter = self.config["max_iter"]
41+
self._clamp = self.config["clamp"]
42+
43+
vae_params = self.config["vae_params"]
44+
self._vae = VariationalAutoencoder(
45+
data_name = self.config['data_name'] if self.config['data_name'] else "Temp",
46+
layers=[int(sum(self._model.get_mutable_mask()))] + vae_params['layers'],
47+
mutable_mask=self._model.get_mutable_mask(),
48+
)
49+
50+
if vae_params["train"]:
51+
self._vae.fit(
52+
xtrain=self._model.get_train_data()[0][self._feature_order],
53+
kl_weight=vae_params["kl_weight"],
54+
lambda_reg=vae_params["lambda_reg"],
55+
epochs=vae_params["epochs"],
56+
lr=vae_params["lr"],
57+
batch_size=vae_params["batch_size"],
58+
)
59+
else:
60+
try:
61+
self._vae.load(vae_params["layers"][0])
62+
except FileNotFoundError as exc:
63+
raise FileNotFoundError(
64+
"Loading of Autoencoder failed. {}".format(str(exc))
65+
)
66+
67+
def _hyper_sphere_coordindates(
68+
self, instance, high: int, low: int
69+
) -> Tuple[np.ndarray, np.ndarray]:
70+
"""
71+
:param n_search_samples: int > 0
72+
:param instance: numpy input point array
73+
:param high: float>= 0, h>l; upper bound
74+
:param low: float>= 0, l<h; lower bound
75+
:param p: float>= 1; norm
76+
:return: candidate counterfactuals & distances
77+
"""
78+
delta_instance = np.random.randn(self._n_search_samples, instance.shape[1])
79+
dist = (
80+
np.random.rand(self._n_search_samples) * (high - low) + low
81+
) # length range [l, h)
82+
norm_p = LA.norm(delta_instance, ord=self._p_norm, axis=1)
83+
d_norm = np.divide(dist, norm_p).reshape(-1, 1) # rescale/normalize factor
84+
delta_instance = np.multiply(delta_instance, d_norm)
85+
candidate_counterfactuals = instance + delta_instance
86+
return candidate_counterfactuals, dist
87+
88+
def _counterfactual_search(
89+
self, step: int, factual: torch.Tensor, cat_features_indices: List[list[int]]
90+
) -> pd.DataFrame:
91+
device = "cuda" if torch.cuda.is_available() else "cpu"
92+
93+
# init step size for growing the sphere
94+
low = 0
95+
high = step
96+
# counter
97+
count = 0
98+
counter_step = 1
99+
100+
torch_fact = torch.from_numpy(factual).to(device)
101+
102+
# get predicted label of instance
103+
instance_label = np.argmax(
104+
self._model.predict_proba(torch_fact.float()).cpu().detach().numpy(),
105+
axis=1,
106+
)
107+
108+
# vectorize z
109+
z = self._vae.encode(
110+
torch_fact[:, self._vae.mutable_mask].float()
111+
)[0]
112+
# add the immutable features to the latents
113+
z = torch.cat([z, torch_fact[:, ~self._vae.mutable_mask]], dim=-1)
114+
z = z.cpu().detach().numpy()
115+
z_rep = np.repeat(z.reshape(1, -1), self._n_search_samples, axis=0)
116+
117+
# make copy such that we later easily combine the immutables and the reconstructed mutables
118+
fact_rep = torch_fact.reshape(1, -1).repeat_interleave(
119+
self._n_search_samples, dim=0
120+
)
121+
122+
candidate_dist: List = []
123+
x_ce: Union[np.ndarray, torch.Tensor] = np.array([])
124+
while count <= self._max_iter or len(candidate_dist) <= 0:
125+
count = count + counter_step
126+
if count > self._max_iter:
127+
logging.debug("No counterfactual example found")
128+
return x_ce[0]
129+
130+
# STEP 1 -- SAMPLE POINTS on hyper sphere around instance
131+
latent_neighbourhood, _ = self._hyper_sphere_coordindates(z_rep, high, low)
132+
torch_latent_neighbourhood = (
133+
torch.from_numpy(latent_neighbourhood).to(device).float()
134+
)
135+
x_ce = self._vae.decode(torch_latent_neighbourhood)
136+
137+
# add the immutable features to the reconstruction
138+
temp = fact_rep.clone()
139+
temp[:, self._vae.mutable_mask] = x_ce.float()
140+
x_ce = temp
141+
142+
x_ce = reconstruct_encoding_constraints(
143+
x_ce, cat_features_indices
144+
)
145+
x_ce = x_ce.detach().cpu().numpy()
146+
x_ce = x_ce.clip(0, 1) if self._clamp else x_ce
147+
148+
# STEP 2 -- COMPUTE l1 & l2 norms
149+
if self._p_norm == 1:
150+
distances = np.abs((x_ce - torch_fact.cpu().detach().numpy())).sum(
151+
axis=1
152+
)
153+
elif self._p_norm == 2:
154+
distances = LA.norm(x_ce - torch_fact.cpu().detach().numpy(), axis=1)
155+
else:
156+
raise ValueError("Possible values for p_norm are 1 or 2")
157+
158+
# counterfactual labels
159+
y_candidate = np.argmax(
160+
self._model.predict_proba(torch.from_numpy(x_ce).float())
161+
.cpu()
162+
.detach()
163+
.numpy(),
164+
axis=1,
165+
)
166+
indices = np.where(y_candidate != instance_label)
167+
candidate_counterfactuals = x_ce[indices]
168+
candidate_dist = distances[indices]
169+
# no candidate found & push search range outside
170+
if len(candidate_dist) == 0:
171+
low = high
172+
high = low + step
173+
elif len(candidate_dist) > 0:
174+
# certain candidates generated
175+
min_index = np.argmin(candidate_dist)
176+
logging.debug("Counterfactual example found")
177+
return candidate_counterfactuals[min_index]
178+
179+
def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
180+
factuals = factuals[self._feature_order] # ensure the feature ordering is correct for the model input
181+
182+
# pay attention to categorical features
183+
encoded_feature_names = self._data.get_categorical_features(expanded=True)
184+
185+
cat_features_indices = []
186+
for features in encoded_feature_names:
187+
# Find the indices of these encoded features in the processed dataframe
188+
indices = [factuals.columns.get_loc(feat) for feat in features]
189+
cat_features_indices.append(indices)
190+
191+
df_cfs = factuals.apply(
192+
lambda x: self._counterfactual_search(
193+
self._step, x.reshape((1, -1)), cat_features_indices
194+
),
195+
raw=True,
196+
axis=1,
197+
)
198+
199+
df_cfs = check_counterfactuals(self._model, self._data, df_cfs, factuals.index)
200+
return df_cfs
201+

method/catalog/REVISE/method.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from data.data_object import DataObject
99
from evaluation.utils import check_counterfactuals
10-
from evaluation.utils import check_counterfactuals
1110
from experiment_utils import deep_merge, reconstruct_encoding_constraints
1211
from method.method_factory import register_method
1312
from method.method_object import MethodObject
@@ -60,7 +59,7 @@ def __init__(self, data: DataObject, model: ModelObject, vae = None, config_over
6059
)
6160
if vae_params['train']:
6261
self.vae.fit(
63-
xtrain=self._model.get_train_data()[0],
62+
xtrain=self._model.get_train_data()[0][self._feature_order],
6463
lambda_reg=vae_params['lambda_reg'],
6564
epochs=vae_params['epochs'],
6665
lr=vae_params['lr'],

0 commit comments

Comments
 (0)