@@ -350,7 +350,14 @@ def __init__(self, properties=None, parent=None, id=ID_CLASSIFIER, **kwargs):
350350 'Load Default Training Set?' , wx .YES_NO | wx .ICON_QUESTION )
351351 response = dlg .ShowModal ()
352352 if response == wx .ID_YES :
353- self .LoadTrainingSet (p .training_set )
353+ name , file_extension = os .path .splitext (p .training_set )
354+ if '.txt' == file_extension :
355+ self .LoadTrainingSet (p .training_set )
356+ elif '.csv' == file_extension :
357+ self .LoadTrainingSetCSV (p .training_set )
358+ else :
359+ logging .error ("Couldn't load the file! Make sure it is .txt or .csv" )
360+ #self.LoadTrainingSet(p.training_set)
354361
355362 self .AutoSave () # Autosave try out
356363
@@ -474,10 +481,6 @@ def CreateMenus(self):
474481 help = 'Loads objects and classes specified in a training set file.' )
475482 self .saveTSMenuItem = self .fileMenu .Append (- 1 , text = 'Save Training Set\t Ctrl+S' ,
476483 help = 'Save your training set to file so you can reload these classified cells again.' )
477- self .loadFullTSMenuItem = self .fileMenu .Append (- 1 , text = 'Load Training Set (CSV)' ,
478- help = 'Loads objects and classes specified in a training set file.' )
479- self .saveFullTSMenuItem = self .fileMenu .Append (- 1 , text = 'Save Training Set (CSV)' ,
480- help = 'Save your training data as CSV' )
481484 self .fileMenu .AppendSeparator ()
482485 # JEN - Start Add
483486 self .loadModelMenuItem = self .fileMenu .Append (- 1 , text = 'Load Classifier Model' , help = 'Loads a classifier model specified in a text file' )
@@ -552,9 +555,7 @@ def CreateMenus(self):
552555
553556 # Bind events to different menu items
554557 self .Bind (wx .EVT_MENU , self .OnLoadTrainingSet , self .loadTSMenuItem )
555- self .Bind (wx .EVT_MENU , self .OnLoadFullTrainingSet , self .loadFullTSMenuItem )
556558 self .Bind (wx .EVT_MENU , self .OnSaveTrainingSet , self .saveTSMenuItem )
557- self .Bind (wx .EVT_MENU , self .OnSaveFullTrainingSet , self .saveFullTSMenuItem )
558559 self .Bind (wx .EVT_MENU , self .OnLoadModel , self .loadModelMenuItem ) # JEN - Added
559560 self .Bind (wx .EVT_MENU , self .SaveModel , self .saveModelMenuItem ) # JEN - Added
560561 self .Bind (wx .EVT_MENU , self .OnShowImageControls , imageControlsMenuItem )
@@ -1076,23 +1077,29 @@ def OnLoadTrainingSet(self, evt):
10761077 '''
10771078 dlg = wx .FileDialog (self , "Select the file containing your classifier training set." ,
10781079 defaultDir = os .getcwd (),
1079- wildcard = 'Text files(*.txt)|*.txt|All files(*.* )|*.* ' ,
1080+ wildcard = 'Text files(*.txt)|*.txt|CSV files(*.csv )|*.csv ' ,
10801081 style = wx .OPEN | wx .FD_CHANGE_DIR )
10811082 if dlg .ShowModal () == wx .ID_OK :
10821083 filename = dlg .GetPath ()
1083- self .LoadTrainingSet (filename )
1084-
1085- def OnLoadFullTrainingSet (self , evt ):
1086- '''
1087- Present user with file select dialog, then load selected training set.
1088- '''
1089- dlg = wx .FileDialog (self , "Select the file containing your classifier training set." ,
1090- defaultDir = os .getcwd (),
1091- wildcard = 'Text files(*.csv)|*.csv|All files(*.*)|*.*' ,
1092- style = wx .OPEN | wx .FD_CHANGE_DIR )
1093- if dlg .ShowModal () == wx .ID_OK :
1094- filename = dlg .GetPath ()
1095- self .LoadTrainingSetCSV (filename )
1084+ name , file_extension = os .path .splitext (filename )
1085+ if '.txt' == file_extension :
1086+ self .LoadTrainingSet (filename )
1087+ elif '.csv' == file_extension :
1088+ self .LoadTrainingSetCSV (filename )
1089+ else :
1090+ logging .error ("Couldn't load the file! Make sure it is .txt or .csv" )
1091+
1092+ # def OnLoadFullTrainingSet(self, evt):
1093+ # '''
1094+ # Present user with file select dialog, then load selected training set.
1095+ # '''
1096+ # dlg = wx.FileDialog(self, "Select the file containing your classifier training set.",
1097+ # defaultDir=os.getcwd(),
1098+ # wildcard='Text files(*.csv)|*.csv|All files(*.*)|*.*',
1099+ # style=wx.OPEN | wx.FD_CHANGE_DIR)
1100+ # if dlg.ShowModal() == wx.ID_OK:
1101+ # filename = dlg.GetPath()
1102+ # self.LoadTrainingSetCSV(filename)
10961103
10971104 def LoadTrainingSet (self , filename ):
10981105 '''
@@ -1156,22 +1163,7 @@ def LoadTrainingSetCSV(self, filename):
11561163 def OnSaveTrainingSet (self , evt ):
11571164 self .SaveTrainingSet ()
11581165
1159- def OnSaveFullTrainingSet (self , evt ):
1160- self .SaveFullTrainingSet ()
1161-
11621166 def SaveTrainingSet (self ):
1163- if not self .defaultTSFileName :
1164- self .defaultTSFileName = 'MyTrainingSet.txt'
1165- saveDialog = wx .FileDialog (self , message = "Save as:" , defaultDir = os .getcwd (),
1166- defaultFile = self .defaultTSFileName ,
1167- wildcard = 'Text files (*.txt)|*.txt|All files (*.*)|*.*' ,
1168- style = wx .FD_SAVE | wx .FD_OVERWRITE_PROMPT | wx .FD_CHANGE_DIR )
1169- if saveDialog .ShowModal () == wx .ID_OK :
1170- filename = saveDialog .GetPath ()
1171- self .defaultTSFileName = os .path .split (filename )[1 ]
1172- self .SaveTrainingSetAs (filename )
1173-
1174- def SaveFullTrainingSet (self ):
11751167 if not self .defaultTSFileName :
11761168 self .defaultTSFileName = 'MyTrainingSet.csv'
11771169 saveDialog = wx .FileDialog (self , message = "Save as:" , defaultDir = os .getcwd (),
@@ -1183,17 +1175,6 @@ def SaveFullTrainingSet(self):
11831175 self .defaultTSFileName = os .path .split (filename )[1 ]
11841176 self .SaveTrainingSetAsCSV (filename )
11851177
1186- def SaveTrainingSetAs (self , filename ):
1187- classDict = {}
1188- trainingSet = self .trainingSet # Create Save Copy
1189- try :
1190- self .trainingSet = TrainingSet (p )
1191- self .trainingSet .Create ([bin .label for bin in self .classBins ], [bin .GetObjectKeys () for bin in self .classBins ])
1192- except :
1193- logging .info ("Couldn't update TrainingSet. Using last AutoSave." )
1194- self .trainingSet = trainingSet # Use backup
1195- self .trainingSet .Save (filename )
1196-
11971178 def SaveTrainingSetAsCSV (self , filename ):
11981179 classDict = {}
11991180 trainingSet = self .trainingSet # Create Save Copy
@@ -1253,36 +1234,36 @@ def OnEvaluation(self, evt):
12531234 if item .IsChecked ():
12541235 selectedText = item .GetText ()
12551236
1256- if selectedText == "ROC Curve" :
1257- self .PlotROC ()
1258- elif selectedText == "Learning Curve" :
1259- self .PlotLearningCurveWrapper ()
1260- elif selectedText == "Precision Recall Curve" :
1261- self .PlotPrecisionRecall ()
1262- elif selectedText == "Confusion Matrix" :
1237+ # if selectedText == "ROC Curve":
1238+ # self.PlotROC()
1239+ # elif selectedText == "Learning Curve":
1240+ # self.PlotLearningCurveWrapper()
1241+ # elif selectedText == "Precision Recall Curve":
1242+ # self.PlotPrecisionRecall()
1243+ if selectedText == "Confusion Matrix" :
12631244 self .algorithm .ConfusionMatrix ()
12641245 else :
12651246 self .algorithm .CheckProgress ()
12661247
1267- def PlotLearningCurveWrapper (self ):
1268- from sklearn import cross_validation
1269- model = self .algorithm
1270- clf = model .classifier
1271- X_train = self .trainingSet .values
1272- y_train = self .trainingSet .label_array
1273-
1274- # Training Set is currently sorted. I will shuffle it
1275- y_trans = y_train .reshape (y_train .shape [0 ],1 )
1276- df_values = pd .DataFrame (X_train , columns = self .trainingSet .colnames )
1277- df_class = pd .DataFrame (y_trans , columns = ["Class" ])
1278- df = pd .concat ([df_class , df_values ],axis = 1 )
1279- df = df .reindex (np .random .permutation (df .index ))
1280- X_train = df [self .trainingSet .colnames ].values
1281- y_train = df ["Class" ].values
1282-
1283- #cv = cross_validation.StratifiedKFold(y_train, n_folds=3, shuffle=False)
1284- plot_title = 'Learning Curves ({})' .format (model .name )
1285- self .PlotLearningCurve (clf , plot_title , X_train , y_train , cv = 5 )
1248+ # def PlotLearningCurveWrapper(self):
1249+ # from sklearn import cross_validation
1250+ # model = self.algorithm
1251+ # clf = model.classifier
1252+ # X_train = self.trainingSet.values
1253+ # y_train = self.trainingSet.label_array
1254+
1255+ # # Training Set is currently sorted. I will shuffle it
1256+ # y_trans = y_train.reshape(y_train.shape[0],1)
1257+ # df_values = pd.DataFrame(X_train, columns=self.trainingSet.colnames)
1258+ # df_class = pd.DataFrame(y_trans, columns=["Class"])
1259+ # df = pd.concat([df_class, df_values],axis=1)
1260+ # df = df.reindex(np.random.permutation(df.index))
1261+ # X_train = df[self.trainingSet.colnames].values
1262+ # y_train = df["Class"].values
1263+
1264+ # #cv = cross_validation.StratifiedKFold(y_train, n_folds=3, shuffle=False)
1265+ # plot_title = 'Learning Curves ({})'.format(model.name)
1266+ # self.PlotLearningCurve(clf, plot_title, X_train, y_train, cv=5)
12861267
12871268 from utils import delay
12881269 # Add AutoSave by DD
0 commit comments