Skip to content

Commit 11194ba

Browse files
committed
Refactor parameter handling for consistency
1 parent 77120e2 commit 11194ba

3 files changed

Lines changed: 11 additions & 12 deletions

File tree

models/cnn.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,13 @@ def __init__(self, num_classes: int = 10) -> None:
9090
self._input_channels = 1 # grayscale CIFAR-10
9191

9292
def create_model(self, **params) -> None:
93-
kernel_size = params.get("kernel_size", 3)
94-
stride = params.get("stride", 1)
95-
self.params.update(
96-
{
97-
"kernel_size": kernel_size,
98-
"stride": stride,
99-
}
100-
)
93+
# Store all parameters passed in
94+
self.params.update(params)
95+
96+
# Extract architecture-specific parameters for Backbone creation
97+
kernel_size = self.params.get("kernel_size", 3)
98+
stride = self.params.get("stride", 1)
99+
101100
self.network = Backbone(
102101
in_channels=self._input_channels,
103102
num_classes=self.num_classes,

models/decision_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def __init__(self, **kwargs: Any) -> None:
1717

1818
def create_model(self, **params: Any) -> None:
1919
"""Create the underlying sklearn estimator."""
20-
configuration = {**self.params, **params}
21-
self.estimator = DecisionTreeClassifier(**configuration)
20+
self.params.update(params)
21+
self.estimator = DecisionTreeClassifier(**self.params)
2222

2323
def train(self, X_train, y_train) -> DecisionTreeClassifier:
2424
if self.estimator is None:

models/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def __init__(self, **kwargs: Any) -> None:
1616
self.estimator: KNeighborsClassifier | None = None
1717

1818
def create_model(self, **params: Any) -> None:
19-
configuration = {**self.params, **params}
20-
self.estimator = KNeighborsClassifier(**configuration)
19+
self.params.update(params)
20+
self.estimator = KNeighborsClassifier(**self.params)
2121

2222
def train(self, X_train, y_train) -> KNeighborsClassifier:
2323
if self.estimator is None:

0 commit comments

Comments
 (0)