Skip to content
This repository was archived by the owner on Dec 15, 2025. It is now read-only.

Commit c31c908

Browse files
committed
Change name GradientBoostingTree to GradientBoostedTree
1 parent fc3aed3 commit c31c908

5 files changed

Lines changed: 229 additions & 15 deletions

File tree

bin/workloads/ml/gbt/prepare/prepare.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ show_bannar start
2626
rmr_hdfs $INPUT_HDFS || true
2727
START_TIME=`timestamp`
2828

29-
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostingTreeDataGenerator $INPUT_HDFS $NUM_EXAMPLES_GBT $NUM_FEATURES_GBT
29+
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostedTreeDataGenerator $INPUT_HDFS $NUM_EXAMPLES_GBT $NUM_FEATURES_GBT
3030

3131
END_TIME=`timestamp`
3232

bin/workloads/ml/gbt/spark/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ rmr_hdfs $OUTPUT_HDFS || true
2626

2727
SIZE=`dir_size $INPUT_HDFS`
2828
START_TIME=`timestamp`
29-
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostingTree ${INPUT_HDFS} ${NUM_ITERATIONS_GBT}
29+
run_spark_job com.intel.hibench.sparkbench.ml.GradientBoostedTree ${INPUT_HDFS} ${NUM_ITERATIONS_GBT}
3030
END_TIME=`timestamp`
3131

3232
gen_report ${START_TIME} ${END_TIME} ${SIZE}

conf/workloads/ml/gbt.conf

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1-
hibench.gbt.tiny.examples 10
2-
hibench.gbt.tiny.features 100
3-
hibench.gbt.small.examples 100
4-
hibench.gbt.small.features 500
5-
hibench.gbt.large.examples 1000
6-
hibench.gbt.large.features 2000
7-
hibench.gbt.huge.examples 1000
8-
hibench.gbt.huge.features 4000
9-
hibench.gbt.gigantic.examples 1000
10-
hibench.gbt.gigantic.features 8000
11-
hibench.gbt.bigdata.examples 1000
12-
hibench.gbt.bigdata.features 12000
1+
hibench.gbt.tiny.examples 10
2+
hibench.gbt.tiny.features 100
3+
hibench.gbt.small.examples 100
4+
hibench.gbt.small.features 500
5+
hibench.gbt.large.examples 1000
6+
hibench.gbt.large.features 2000
7+
hibench.gbt.huge.examples 1000
8+
hibench.gbt.huge.features 4000
9+
hibench.gbt.gigantic.examples 1000
10+
hibench.gbt.gigantic.features 8000
11+
hibench.gbt.bigdata.examples 1000
12+
hibench.gbt.bigdata.features 12000
1313

1414

1515
hibench.gbt.examples ${hibench.gbt.${hibench.scale.profile}.examples}
1616
hibench.gbt.features ${hibench.gbt.${hibench.scale.profile}.features}
1717
hibench.gbt.partitions ${hibench.default.map.parallelism}
18-
hibench.gbt.numIterations 100
18+
19+
hibench.gbt.numClasses 2
20+
hibench.gbt.maxDepth 30
21+
hibench.gbt.maxBins 32
22+
hibench.gbt.numIterations 20
23+
hibench.gbt.learningRate 0.1
1924

2025
hibench.workload.input ${hibench.hdfs.data.dir}/GBT/Input
2126
hibench.workload.output ${hibench.hdfs.data.dir}/GBT/Output
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package com.intel.hibench.sparkbench.ml
19+
20+
import org.apache.spark.{SparkConf, SparkContext}
21+
import org.apache.spark.mllib.tree.GradientBoostedTrees
22+
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
23+
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
24+
import org.apache.spark.rdd.RDD
25+
import org.apache.spark.mllib.regression.LabeledPoint
26+
27+
import scopt.OptionParser
28+
29+
object GradientBoostedTree {
30+
31+
case class Params(
32+
numClasses: Int = 2,
33+
maxDepth: Int = 30,
34+
maxBins: Int = 32,
35+
numIterations: Int = 20,
36+
learningRate: Double = 0.1,
37+
dataPath: String = null
38+
)
39+
40+
def main(args: Array[String]): Unit = {
41+
val defaultParams = Params()
42+
43+
val parser = new OptionParser[Params]("GBT"){
44+
head("GBT: an example of Gradient Boosted Tree for classification")
45+
opt[Int]("numClasses")
46+
.text(s"numClasses, default: ${defaultParams.numClasses}")
47+
.action((x,c) => c.copy(numClasses = x))
48+
opt[Int]("maxDepth")
49+
.text(s"maxDepth, default: ${defaultParams.maxDepth}")
50+
.action((x,c) => c.copy(maxDepth = x))
51+
opt[Int]("maxBins")
52+
.text(s"maxBins, default: ${defaultParams.maxBins}")
53+
.action((x,c) => c.copy(maxBins = x))
54+
opt[Int]("numIterations")
55+
.text(s"numIterations, default: ${defaultParams.numIterations}")
56+
.action((x,c) => c.copy(numIterations = x))
57+
opt[Double]("learningRate")
58+
.text(s"learningRate, default: ${defaultParams.learningRate}")
59+
.action((x,c) => c.copy(learningRate = x))
60+
arg[String]("<dataPath>")
61+
.required()
62+
.text("data path for Gradient Boosted Tree")
63+
.action((xc) => c.copy(dataPath = x))
64+
}
65+
parser.parse(args, defaultParams) match {
66+
case some(params) => run(params)
67+
case _ => sys.exit(1)
68+
}
69+
}
70+
71+
def run(params: Params): Unit = {
72+
val conf = new SparkConf().setAppName(s"Gradient Boosted Tree with $params")
73+
val sc = new SparkContext(conf)
74+
75+
val dataPath = params.dataPath
76+
val numClasses = params.numClasses
77+
val maxDepth = params.maxDepth
78+
val maxBins = params.maxBins
79+
val numIterations = params.numIterations
80+
val learningRate = params.learningRate
81+
82+
// Load data file.
83+
val data: RDD[LabeledPoint] = sc.objectFile(dataPath)
84+
85+
// Split the data into training and test sets (30% held out for testing)
86+
val splits = data.randomSplit(Array(0.7, 0.3))
87+
val (trainingData, testData) = (splits(0), splits(1))
88+
89+
// Train a GradientBoostedTrees model.
90+
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
91+
boostingStrategy.numIterations = numIterations
92+
boostingStrategy.learningRate = learningRate
93+
boostingStrategy.treeStrategy.numClasses = numClasses
94+
boostingStrategy.treeStrategy.maxDepth = maxDepth
95+
boostingStrategy.treeStrategy.maxBins = maxBins
96+
// Empty categoricalFeaturesInfo indicates all features are continuous.
97+
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
98+
99+
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
100+
101+
// Evaluate model on test instances and compute test error
102+
val labelAndPreds = testData.map { point =>
103+
val prediction = model.predict(point.features)
104+
(point.label, prediction)
105+
}
106+
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
107+
println("Test Error = " + testErr)
108+
109+
sc.stop()
110+
}
111+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package com.intel.hibench.sparkbench.ml
19+
20+
import com.intel.hibench.sparkbench.common.IOCommon
21+
22+
import scala.util.Random
23+
24+
import org.apache.spark.{SparkConf, SparkContext}
25+
import org.apache.spark.annotation.{DeveloperApi, Since}
26+
import org.apache.spark.mllib.linalg.Vectors
27+
import org.apache.spark.mllib.regression.LabeledPoint
28+
import org.apache.spark.rdd.RDD
29+
30+
/**
31+
* :: DeveloperApi ::
32+
* Generate test data for Gradient Boosting Tree. This class chooses positive labels
33+
* with probability `probOne` and scales features for positive examples by `eps`.
34+
*/
35+
object GradientBoostingTreeDataGenerator {
36+
37+
/**
38+
* Generate an RDD containing test data for Gradient Boosting Tree.
39+
*
40+
* @param sc SparkContext to use for creating the RDD.
41+
* @param nexamples Number of examples that will be contained in the RDD.
42+
* @param nfeatures Number of features to generate for each example.
43+
* @param eps Epsilon factor by which positive examples are scaled.
44+
* @param nparts Number of partitions of the generated RDD. Default value is 2.
45+
* @param probOne Probability that a label is 1 (and not 0). Default value is 0.5.
46+
*/
47+
def generateGBTRDD(
48+
sc: SparkContext,
49+
nexamples: Int,
50+
nfeatures: Int,
51+
eps: Double,
52+
nparts: Int = 2,
53+
probOne: Double = 0.5): RDD[LabeledPoint] = {
54+
val data = sc.parallelize(0 until nexamples, nparts).map { idx =>
55+
val rnd = new Random(42 + idx)
56+
57+
val y = if (idx % 2 == 0) 0.0 else 1.0
58+
val x = Array.fill[Double](nfeatures) {
59+
rnd.nextGaussian() + (y * eps)
60+
}
61+
LabeledPoint(y, Vectors.dense(x))
62+
}
63+
data
64+
}
65+
66+
def main(args: Array[String]) {
67+
val conf = new SparkConf().setAppName("GradientBoostingTreeDataGenerator")
68+
val sc = new SparkContext(conf)
69+
70+
var outputPath = ""
71+
var numExamples: Int = 200000
72+
var numFeatures: Int = 20
73+
val parallel = sc.getConf.getInt("spark.default.parallelism", sc.defaultParallelism)
74+
val numPartitions = IOCommon.getProperty("hibench.default.shuffle.parallelism")
75+
.getOrElse((parallel / 2).toString).toInt
76+
val eps = 0.3
77+
78+
if (args.length == 3) {
79+
outputPath = args(0)
80+
numExamples = args(1).toInt
81+
numFeatures = args(2).toInt
82+
println(s"Output Path: $outputPath")
83+
println(s"Num of Examples: $numExamples")
84+
println(s"Num of Features: $numFeatures")
85+
} else {
86+
System.err.println(
87+
s"Usage: $GradientBoostingTreeDataGenerator <OUTPUT_PATH> <NUM_EXAMPLES> <NUM_FEATURES>"
88+
)
89+
System.exit(1)
90+
}
91+
92+
val data = generateGBTRDD(sc, numExamples, numFeatures, eps, numPartitions)
93+
94+
data.saveAsObjectFile(outputPath)
95+
96+
sc.stop()
97+
}
98+
}

0 commit comments

Comments
 (0)