Skip to content

Commit 88cbada

Browse files
committed
add train test example
1 parent f69a8e7 commit 88cbada

1 file changed

Lines changed: 118 additions & 0 deletions

File tree

examples/ex_mnist_train_test.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from sklearn.metrics import balanced_accuracy_score
2+
3+
4+
from eisp.ensemble import Ensemble
5+
from eisp.proxy_tasks import FeatureVectors
6+
from eisp.visualization import (
7+
plot_confusion_matrix,
8+
plot_feature_importance,
9+
)
10+
import numpy as np
11+
12+
import torchvision
13+
from torch.utils.data import DataLoader
14+
15+
# Preprocessing: ToTensor followed by normalization with the MNIST mean and
# standard deviation.
mnist_normalize = torchvision.transforms.Normalize((0.1307,), (0.3081,))
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), mnist_normalize]
)

train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# shuffle=False: every pass over the loader visits the samples in dataset order.
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=64)
29+
30+
# Define simple feature extraction functions
31+
feature_names = ["image_itself", "image_mean", "image_std"]
32+
33+
34+
def image_itself(x):
    """Return every sample of the batch flattened into one row.

    Args:
        x: Tensor whose first dimension indexes the samples of the batch.

    Returns:
        NumPy array with ``x.size(0)`` rows; each row holds all remaining
        dimensions of one sample flattened together.
    """
    # reshape (unlike view) also accepts non-contiguous tensors; detach/cpu
    # keep the NumPy conversion from raising on tensors that track gradients
    # or are not stored on the CPU.
    return x.reshape(x.size(0), -1).detach().cpu().numpy()
36+
37+
38+
def image_mean(x):
    """Return the mean over all values of each sample in the batch.

    Args:
        x: Tensor whose first dimension indexes the samples of the batch.

    Returns:
        NumPy array of shape ``(x.size(0), 1)``: one mean per sample.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    flat = x.reshape(x.size(0), -1)
    # keepdim=True yields the (batch, 1) column directly; detach/cpu keep the
    # NumPy conversion from raising on tensors that track gradients or are
    # not stored on the CPU.
    return flat.mean(dim=1, keepdim=True).detach().cpu().numpy()
40+
41+
42+
def image_std(x):
    """Return the standard deviation over all values of each sample.

    The deviation is computed with ``torch.Tensor.std`` using its default
    settings.

    Args:
        x: Tensor whose first dimension indexes the samples of the batch.

    Returns:
        NumPy array of shape ``(x.size(0), 1)``: one deviation per sample.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    flat = x.reshape(x.size(0), -1)
    # keepdim=True yields the (batch, 1) column directly; detach/cpu keep the
    # NumPy conversion from raising on tensors that track gradients or are
    # not stored on the CPU.
    return flat.std(dim=1, keepdim=True).detach().cpu().numpy()
44+
45+
46+
# Listed in the same order as feature_names; presumably the two lists are
# paired by position inside FeatureVectors.extract -- verify against eisp.
feature_functions = [image_itself, image_mean, image_std]

# Extract features for the training set. store_path is forwarded to
# FeatureVectors.extract together with the loader, functions and names.
train_feature_path = "./data/mnist_train_features"
features: FeatureVectors = FeatureVectors.extract(
    train_loader, feature_functions, feature_names, store_path=train_feature_path
)
56+
57+
58+
# Gather the targets with one more pass over the loader. The loader is created
# with shuffle=False, so the labels come back in the same order as the samples
# that were handed to the feature extraction.
labels = np.concatenate([target.numpy() for _, target in train_loader], axis=0)

# Print features and labels shape. The feature mapping is fetched once instead
# of once per feature name.
extracted = features.get_all_features()
for name in feature_names:
    print(f"Features shape: {extracted[name].shape}")
print(f"Labels shape: {labels.shape}")
67+
68+
# Split the extracted features (test_size=0.2, fixed random_state). The
# returned indices are used to slice the labels so that they stay aligned with
# the two feature splits.
(
    train_features,
    test_features,
    train_indices,
    test_indices,
) = features.train_test_split(test_size=0.2, random_state=42)

train_labels, test_labels = labels[train_indices], labels[test_indices]
74+
75+
def balanced_accuracy_from_scores(y_true, y_pred):
    """Return balanced accuracy after taking the argmax of y_pred over axis 1."""
    predicted_classes = np.argmax(y_pred, axis=1)
    return balanced_accuracy_score(y_true, predicted_classes)


# Build the ensemble on the training split and train it, maximizing the
# balanced-accuracy metric defined above.
ensemble_model = Ensemble(train_features, train_labels)
ensemble_model.train(
    model_type="xgboost",
    optimization_trials=5,
    optimization_direction="maximize",
    metric_function=balanced_accuracy_from_scores,
    should_extract_shap=True,
)
86+
87+
# SHAP results exposed by the trained ensemble.
shap_values = ensemble_model.shap
shap_aggregated = ensemble_model.shap_aggregated

# Render the aggregated SHAP values as a feature-importance plot.
feature_importance_save_path = "./data/mnist_vis/feature_importance.png"
plot_feature_importance(shap_aggregated, save_path=feature_importance_save_path)
print(f"Feature importance plot saved to {feature_importance_save_path}")

# Show the shape of each SHAP entry, then the aggregated values themselves.
shap_shapes = {}
for key, value in shap_values.items():
    shap_shapes[key] = value.shape
print(shap_shapes)
print(dict(shap_aggregated.items()))
100+
101+
print("Ensemble training on MNIST completed successfully.")
print(f"Val metric: {ensemble_model.val_metric}")

# Confusion matrix over the labels stored on the ensemble; the stored
# predictions are reduced to class indices with an argmax over axis 1.
confusion_matrix_save_path = "./data/mnist_vis/confusion_matrix.png"
digit_names = list(map(str, range(10)))
plot_confusion_matrix(
    true_labels=ensemble_model.true_labels,
    pred_labels=np.argmax(ensemble_model.pred_labels, axis=1),
    class_names=digit_names,
    save_path=confusion_matrix_save_path,
)
print(f"Confusion matrix plot saved to {confusion_matrix_save_path}")
113+
114+
# Fetch the held-out feature mapping once instead of once per feature name,
# then join the per-feature arrays column-wise in feature_names order.
test_feature_map = test_features.get_all_features()
all_test_features = np.concatenate(
    [test_feature_map[name] for name in feature_names], axis=1
)

# Evaluate the trained ensemble on the held-out split.
ensemble_model.test_xgboost(all_test_features, test_labels)

0 commit comments

Comments
 (0)