1515 * limitations under the License.
1616 */
1717
18- // scalastyle:off println
1918package com .intel .hibench .sparkbench .ml
2019
2120import org .apache .spark .{SparkConf , SparkContext }
22- // $example on$
2321import org .apache .spark .mllib .classification .{SVMModel , SVMWithSGD }
2422import org .apache .spark .mllib .evaluation .BinaryClassificationMetrics
25- import org .apache .spark .mllib .util .MLUtils
2623import org .apache .spark .rdd .RDD
2724import org .apache .spark .mllib .regression .LabeledPoint
28- // $example off$
25+
26+ import scopt .OptionParser
2927
3028object SVMWithSGDExample {
3129
/**
 * Command-line configuration for the SVM job.
 *
 * @param numIterations iteration count handed to `SVMWithSGD.train`
 * @param stepSize      step size handed to `SVMWithSGD.train`
 * @param regParam      regularization parameter handed to `SVMWithSGD.train`
 * @param dataPath      location of the input data read with `sc.objectFile`;
 *                      the `null` default is only a placeholder, because the
 *                      parser in `main` declares `<dataPath>` as a required
 *                      argument and overwrites it before `run` is called
 */
case class Params(
  numIterations: Int = 100,
  stepSize: Double = 1.0,
  regParam: Double = 0.01,
  dataPath: String = null
)
36+
/**
 * Command-line entry point.
 *
 * Builds a scopt parser that accepts the optional `numIterations`,
 * `stepSize` and `regParam` options plus one required `<dataPath>`
 * argument, each of which overrides the matching field of [[Params]].
 * When parsing succeeds the resulting parameters are passed to `run`;
 * otherwise the JVM exits with status 1.
 *
 * @param args raw command-line arguments
 */
def main(args: Array[String]): Unit = {
  // The defaults are also quoted in each option's help text below.
  val defaultParams = Params()

  val parser = new OptionParser[Params]("SVM") {
    head("SVM: an example of SVM for classification.")
    opt[Int]("numIterations")
      .text(s"numIterations, default: ${defaultParams.numIterations}")
      .action((x, c) => c.copy(numIterations = x))
    opt[Double]("stepSize")
      .text(s"stepSize, default: ${defaultParams.stepSize}")
      .action((x, c) => c.copy(stepSize = x))
    opt[Double]("regParam")
      .text(s"regParam, default: ${defaultParams.regParam}")
      .action((x, c) => c.copy(regParam = x))
    arg[String]("<dataPath>")
      .required()
      .text("data path of SVM")
      .action((x, c) => c.copy(dataPath = x))
  }

  parser.parse(args, defaultParams) match {
    case Some(params) => run(params)
    // Parsing failed: terminate with a non-zero status.
    case _ => sys.exit(1)
  }
}
3761
38- val conf = new SparkConf ().setAppName(" SVMWithSGDExample" )
62+ def run (params : Params ): Unit = {
63+
64+ val conf = new SparkConf ().setAppName(s " SVM with $params" )
3965 val sc = new SparkContext (conf)
4066
41- // $example on$
42- val data : RDD [LabeledPoint ] = sc.objectFile(inputPath)
67+ val dataPath = params.dataPath
68+ val numIterations = params.numIterations
69+ val stepSize = params.stepSize
70+ val regParam = params.regParam
71+
72+ val data : RDD [LabeledPoint ] = sc.objectFile(dataPath)
4373
4474 // Split data into training (60%) and test (40%).
4575 val splits = data.randomSplit(Array (0.6 , 0.4 ), seed = 11L )
4676 val training = splits(0 ).cache()
4777 val test = splits(1 )
4878
4979 // Run training algorithm to build the model
50- val numIterations = 100
51- val model = SVMWithSGD .train(training, numIterations)
80+ val model = SVMWithSGD .train(training, numIterations, stepSize, regParam)
5281
5382 // Clear the default threshold.
5483 model.clearThreshold()
@@ -65,8 +94,6 @@ object SVMWithSGDExample {
6594
6695 println(" Area under ROC = " + auROC)
6796
68- // Save and load model
6997 sc.stop()
7098 }
7199}
72- // scalastyle:on println
0 commit comments