File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11"""Abstract interface for models used in the hyperparameter tuning framework."""
22
33from abc import ABC , abstractmethod
4- from typing import Dict , Any , Literal
4+ from typing import Dict , Any , Literal , overload
5+
6+ from models .cnn import CNNModel
7+ from models .decision_tree import DecisionTreeModel
8+ from models .knn import KNNModel
59
610from .ParamSpace import ParamSpace
711
@@ -38,12 +42,19 @@ def get_param_space(self) -> Dict[str, ParamSpace]:
3842 raise NotImplementedError
3943
4044
41- def get_model_by_name (model_name : Literal ["dt" , "knn" , "cnn" ]) -> BaseModel :
42- """Factory function to get model by name."""
43- from models .decision_tree import DecisionTreeModel
44- from models .knn import KNNModel
45- from models .cnn import CNNModel
45+ @overload
46+ def get_model_by_name (model_name : Literal ["dt" ]) -> DecisionTreeModel :
47+ ...
48+
49+ @overload
50+ def get_model_by_name (model_name : Literal ["knn" ]) -> KNNModel :
51+ ...
52+
53+ @overload
54+ def get_model_by_name (model_name : Literal ["cnn" ]) -> CNNModel :
55+ ...
4656
57+ def get_model_by_name (model_name : Literal ["dt" , "knn" , "cnn" ]) -> KNNModel | DecisionTreeModel | CNNModel :
4758 models = {
4859 "dt" : DecisionTreeModel ,
4960 "knn" : KNNModel ,
You can’t perform that action at this time.
0 commit comments