Skip to content

Commit 2ed19c5

Browse files
committed
Add type overloads for model factory
1 parent f66cdd8 commit 2ed19c5

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

models/base.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Abstract interface for models used in the hyperparameter tuning framework."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Dict, Any, Literal
4+
from typing import Dict, Any, Literal, overload
5+
6+
from models.cnn import CNNModel
7+
from models.decision_tree import DecisionTreeModel
8+
from models.knn import KNNModel
59

610
from .ParamSpace import ParamSpace
711

@@ -38,12 +42,19 @@ def get_param_space(self) -> Dict[str, ParamSpace]:
3842
raise NotImplementedError
3943

4044

41-
def get_model_by_name(model_name: Literal["dt", "knn", "cnn"]) -> BaseModel:
42-
"""Factory function to get model by name."""
43-
from models.decision_tree import DecisionTreeModel
44-
from models.knn import KNNModel
45-
from models.cnn import CNNModel
45+
@overload
46+
def get_model_by_name(model_name: Literal["dt"]) -> DecisionTreeModel:
47+
...
48+
49+
@overload
50+
def get_model_by_name(model_name: Literal["knn"]) -> KNNModel:
51+
...
52+
53+
@overload
54+
def get_model_by_name(model_name: Literal["cnn"]) -> CNNModel:
55+
...
4656

57+
def get_model_by_name(model_name: Literal["dt", "knn", "cnn"]) -> KNNModel | DecisionTreeModel | CNNModel:
4758
models = {
4859
"dt": DecisionTreeModel,
4960
"knn": KNNModel,

0 commit comments

Comments
 (0)