Skip to content

Commit 298648a

Browse files
committed
save feature names when saving model
1 parent 9bf5fc6 commit 298648a

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

cpa/classifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,7 @@ def SaveModel(self, evt=None):
10761076
filename = saveDialog.GetPath()
10771077
self.defaultModelFileName = os.path.split(filename)[1]
10781078
bin_labels = [bin.label for bin in self.classBins]
1079+
self.algorithm._set_features(self.trainingset.colnames)
10791080
self.algorithm.SaveModel(filename, bin_labels)
10801081
self.PostMessage('Classifier model succesfully saved.')
10811082

cpa/generalclassifier.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@ def __init__(self, classifier = "discriminant_analysis.LinearDiscriminantAnalysi
2222
self.trained = False
2323
self.env = env # Env is Classifier in Legacy Code -- maybe renaming ?
2424
self.name = self.name()
25+
self.features = []
2526

2627
logging.info('Initialized New Classifier: ' + self.name)
2728

29+
# Set features
30+
def _set_features(self, features):
31+
self.features = features
32+
2833
# Return name
2934
def name(self):
3035
return self.classifier.__class__.__name__
@@ -78,7 +83,7 @@ def IsTrained(self):
7883
def LoadModel(self, model_filename):
7984

8085
try:
81-
self.classifier, self.bin_labels, self.name = joblib.load(model_filename)
86+
self.classifier, self.bin_labels, self.name, self.features = joblib.load(model_filename)
8287
except:
8388
self.classifier = None
8489
self.bin_labels = None
@@ -127,7 +132,7 @@ def PredictProba(self, test_values):
127132
logging.info("Selected algorithm doesn't provide probabilities")
128133

129134
def SaveModel(self, model_filename, bin_labels):
130-
joblib.dump((self.classifier, bin_labels, self.name), model_filename, compress=1)
135+
joblib.dump((self.classifier, bin_labels, self.name, self.features), model_filename, compress=1)
131136

132137
def ShowModel(self):#SKLEARN TODO
133138
'''

0 commit comments

Comments
 (0)