Skip to content

Commit 77120e2

Browse files
committed
Refactor model hyperparameters
1 parent 240b698 commit 77120e2

3 files changed

Lines changed: 7 additions & 22 deletions

File tree

models/cnn.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ def __init__(
5353
nn.BatchNorm2d(32),
5454
nn.ReLU(inplace=True),
5555
nn.Conv2d(
56-
32, 64, kernel_size=min(kernel_size, 3), stride=stride, padding=1
56+
32, 64, kernel_size=kernel_size, stride=stride, padding=1
5757
),
5858
nn.BatchNorm2d(64),
5959
nn.ReLU(inplace=True),
6060
nn.Conv2d(
6161
64,
6262
128,
63-
kernel_size=min(kernel_size, 3),
64-
stride=max(1, stride - 1),
63+
kernel_size=kernel_size,
64+
stride=stride,
6565
padding=1,
6666
),
6767
nn.BatchNorm2d(128),
@@ -267,12 +267,12 @@ def get_param_space(self) -> Dict[str, ParamSpace]:
267267
"learning_rate": ParamSpace.float_range(
268268
min_val=1e-5, max_val=1e-2, default=3e-4
269269
),
270-
"batch_size": ParamSpace.integer(min_val=16, max_val=128, default=64),
270+
"batch_size": ParamSpace.categorical(choices=[16, 32, 64, 128], default=64),
271271
"weight_decay": ParamSpace.float_range(
272272
min_val=0.0, max_val=0.01, default=1e-3
273273
),
274274
"optimizer": ParamSpace.categorical(
275-
choices=["AdamW", "Adam", "SGD"], default="AdamW"
275+
choices=["AdamW", "SGD"], default="AdamW"
276276
),
277277
}
278278

@@ -286,12 +286,6 @@ def _build_optimizer(
286286
lr=config.learning_rate,
287287
weight_decay=config.weight_decay,
288288
)
289-
case "Adam":
290-
return optim.Adam(
291-
network.parameters(),
292-
lr=config.learning_rate,
293-
weight_decay=config.weight_decay,
294-
)
295289
case "SGD":
296290
return optim.SGD(
297291
network.parameters(),

models/decision_tree.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,5 @@ def get_param_space(self) -> Dict[str, ParamSpace]:
8080
"min_samples_leaf": ParamSpace.integer(min_val=1, max_val=10, default=2),
8181
"criterion": ParamSpace.categorical(
8282
choices=["gini", "entropy"], default="gini"
83-
),
84-
"splitter": ParamSpace.categorical(
85-
choices=["best", "random"], default="best"
86-
),
83+
)
8784
}

models/knn.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,7 @@ def get_param_space(self) -> Dict[str, ParamSpace]:
6868
"weights": ParamSpace.categorical(
6969
choices=["uniform", "distance"], default="uniform"
7070
),
71-
"algorithm": ParamSpace.categorical(
72-
choices=["auto", "ball_tree", "kd_tree", "brute"], default="auto"
73-
),
74-
"p": ParamSpace.categorical(
75-
choices=[1, 2], default=2
76-
), # Manhattan and Euclidean distance
7771
"metric": ParamSpace.categorical(
78-
choices=["minkowski", "chebyshev", "manhattan"], default="minkowski"
72+
choices=["minkowski", "manhattan"], default="minkowski"
7973
),
8074
}

0 commit comments

Comments
 (0)