Skip to content
This repository was archived by the owner on Dec 15, 2025. It is now read-only.

Commit af6bd6c

Browse files
committed
change the form of arguments of GBT to OptionParser
1 parent c31c908 commit af6bd6c

4 files changed

Lines changed: 8 additions & 4 deletions

File tree

bin/functions/hibench_prop_env_mapping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@
115115
# For Gradient Boosting Tree
116116
NUM_EXAMPLES_GBT="hibench.gbt.examples",
117117
NUM_FEATURES_GBT="hibench.gbt.features",
118+
NUM_CLASSES_GBT="hibench.gbt.numClasses",
119+
MAX_DEPTH_GBT="hibench.gbt.maxDepth",
120+
MAX_BINS_GBT="hibench.gbt.maxBins",
118121
NUM_ITERATIONS_GBT="hibench.gbt.numIterations",
122+
LEARNING_RATE_GBT="hibench.gbt.learningRate",
119123
# For Random Forest
120124
NUM_EXAMPLES_RF="hibench.rf.examples",
121125
NUM_FEATURES_RF="hibench.rf.features",

bin/workloads/ml/gbt/spark/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ rmr_hdfs $OUTPUT_HDFS || true
2626

2727
SIZE=`dir_size $INPUT_HDFS`
2828
START_TIME=`timestamp`
29-
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostedTree ${INPUT_HDFS} ${NUM_ITERATIONS_GBT}
29+
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostedTree --numClasses $NUM_CLASSES_GBT --maxDepth $MAX_DEPTH_GBT --maxBins $MAX_BINS_GBT --numIterations $NUM_ITERATIONS_GBT --learningRate $LEARNING_RATE_GBT $INPUT_HDFS
3030
END_TIME=`timestamp`
3131

3232
gen_report ${START_TIME} ${END_TIME} ${SIZE}

sparkbench/ml/src/main/scala/com/intel/sparkbench/ml/GradientBoostedTree.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ object GradientBoostedTree {
6060
arg[String]("<dataPath>")
6161
.required()
6262
.text("data path for Gradient Boosted Tree")
63-
.action((xc) => c.copy(dataPath = x))
63+
.action((x,c) => c.copy(dataPath = x))
6464
}
6565
parser.parse(args, defaultParams) match {
66-
case some(params) => run(params)
66+
case Some(params) => run(params)
6767
case _ => sys.exit(1)
6868
}
6969
}

sparkbench/ml/src/main/scala/com/intel/sparkbench/ml/GradientBoostedTreeDataGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD
3232
* Generate test data for Gradient Boosting Tree. This class chooses positive labels
3333
* with probability `probOne` and scales features for positive examples by `eps`.
3434
*/
35-
object GradientBoostingTreeDataGenerator {
35+
object GradientBoostedTreeDataGenerator {
3636

3737
/**
3838
* Generate an RDD containing test data for Gradient Boosting Tree.

0 commit comments

Comments
 (0)