-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathlearn_concepts_dataset.py
More file actions
100 lines (86 loc) · 3.14 KB
/
learn_concepts_dataset.py
File metadata and controls
100 lines (86 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import pickle
import torch
import argparse
import numpy as np
from re import sub
from models import get_model
from concepts import learn_concept_bank
from data import get_concept_loaders
def config(argv=None):
    """Parse the command-line options for the concept-learning script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which
            matches the behavior of calling ``config()`` with no arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``backbone_name``,
        ``dataset_name``, ``out_dir``, ``device``, ``seed``, ``num_workers``,
        ``batch_size``, ``C`` (a list of floats) and ``n_samples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--backbone-name", default="resnet18_cub", type=str)
    parser.add_argument("--dataset-name", default="cub", type=str)
    parser.add_argument("--out-dir", required=True, type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument(
        "--num-workers",
        default=4,
        type=int,
        help="Number of workers in the data loader.",
    )
    parser.add_argument(
        "--batch-size",
        default=100,
        type=int,
        help="Batch size in the concept loader.",
    )
    # One or more values may be given; each becomes a separate float in args.C.
    parser.add_argument(
        "--C",
        nargs="+",
        default=[0.01, 0.1],
        type=float,
        help="Regularization parameter for SVMs.",
    )
    parser.add_argument(
        "--n-samples",
        default=50,
        type=int,
        help="Number of positive/negative samples used to learn concepts.",
    )
    return parser.parse_args(argv)
def main():
    """Learn a concept bank for every concept and pickle one file per C value.

    Reads the command-line options via ``config()``, loads the backbone with
    ``get_model``, obtains positive/negative loaders for each concept from
    ``get_concept_loaders`` and calls ``learn_concept_bank`` on each pair.
    The per-concept results are grouped by regularization value ``C`` and
    each group is written to
    ``<out_dir>/<dataset>_<backbone>_<C>_<n_samples>.pkl``.
    """
    args = config()

    # Create the output directory up front: otherwise a missing directory only
    # surfaces at the final open(), after all concepts have been learned.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Bottleneck part of model
    backbone, preprocess = get_model(args, args.backbone_name)
    backbone = backbone.to(args.device).eval()

    # One result dictionary per regularization value, keyed by concept name.
    concept_libs = {C: {} for C in args.C}

    # Get the positive and negative loaders for each concept.
    concept_loaders = get_concept_loaders(
        args.dataset_name,
        preprocess,
        n_samples=args.n_samples,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
    )

    # Seeding happens after the loaders are built, as in the original order.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    for concept_name, loaders in concept_loaders.items():
        pos_loader, neg_loader = loaders["pos"], loaders["neg"]
        # Get CAV for each concept using positive/negative image split
        cav_info = learn_concept_bank(
            pos_loader,
            neg_loader,
            backbone,
            args.n_samples,
            args.C,
            device=args.device,
        )

        # Store the CAV info for each regularization parameter and concept;
        # entries 1 and 2 are reported as train and validation accuracy.
        for C in args.C:
            concept_libs[C][concept_name] = cav_info[C]
            print(
                f"concept_name: {concept_name}, C: {C}, train_acc: {cav_info[C][1]}, val_acc: {cav_info[C][2]}"
            )

    # Colons in the backbone name cause exporting errors, so strip them from
    # the file name only.
    backbone_tag = args.backbone_name.replace(":", "")

    # Save CAV results
    for C, concept_lib in concept_libs.items():
        lib_path = os.path.join(
            args.out_dir,
            f"{args.dataset_name}_{backbone_tag}_{C}_{args.n_samples}.pkl",
        )
        with open(lib_path, "wb") as f:
            pickle.dump(concept_lib, f)
        print(f"Saved to {lib_path}!")

        total_concepts = len(concept_lib)
        print(f"File: {lib_path}, Total: {total_concepts} \n")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()