本文整理汇总了Scala中org.apache.spark.mllib.tree.model.RandomForestModel类的典型用法代码示例。如果您正苦于以下问题:Scala RandomForestModel类的具体用法?Scala RandomForestModel怎么用?Scala RandomForestModel使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
在下文中一共展示了RandomForestModel类的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Scala代码示例。
示例1: MLLibRandomForestModel
//设置package包名称以及导入依赖的类
package com.asto.dmp.articlecate.biz
import com.asto.dmp.articlecate.base.Props
import com.asto.dmp.articlecate.utils.FileUtils
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import com.asto.dmp.articlecate.biz.ClsFeaturesParser._
import scala.collection._
/**
 * Trains, persists and applies an MLlib random-forest classifier.
 *
 * Hyper-parameters are read through `Props` using the keys `model_numTrees`,
 * `model_featureSubsetStrategy`, `model_impurity`, `model_maxDepth`, `model_maxBins`;
 * optional evaluation is controlled by `model_test` and `model_sampleRate`.
 *
 * @param sc        Spark context used to load data and to save/load the model
 * @param modelPath location the trained model is saved to (after deleting what was there)
 *                  and later loaded from
 */
class MLLibRandomForestModel(val sc: SparkContext, val modelPath: String) extends scala.Serializable with Logging {

  /**
   * Trains a random forest on a LIBSVM-format file, replaces the model stored at `modelPath`,
   * and optionally logs an error rate (see `testErrorRate`).
   *
   * @param svmTrainDataPath path of the LIBSVM-format training data
   */
  def genRandomForestModel(svmTrainDataPath: String): Unit = {
    // One class per entry of the category-name -> code mapping.
    val numClasses = ClsFeaturesParser.clsNameToCodeMap.size
    // Empty map: every feature is treated as continuous.
    val categoricalFeaturesInfo = immutable.Map[Int, Int]()
    val numTrees = Props.get("model_numTrees").toInt
    val featureSubsetStrategy = Props.get("model_featureSubsetStrategy")
    val impurity = Props.get("model_impurity")
    val maxDepth = Props.get("model_maxDepth").toInt
    val maxBins = Props.get("model_maxBins").toInt

    val trainingData = MLUtils.loadLibSVMFile(sc, svmTrainDataPath).cache()
    try {
      val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
      FileUtils.deleteFilesInHDFS(modelPath)
      model.save(sc, modelPath)
      testErrorRate(trainingData, model)
    } finally {
      // The cached training set is only needed for training and evaluation; release it.
      trainingData.unpersist()
    }
  }

  /**
   * Logs the misclassification rate of `model` on a random sample of `trainingData`
   * when the `model_test` property is true; otherwise logs that evaluation was skipped.
   *
   * NOTE(review): the sample is drawn from the training data itself, so the logged value is a
   * training-error estimate rather than a held-out test error.
   */
  private def testErrorRate(trainingData: RDD[LabeledPoint], model: RandomForestModel): Unit = {
    if (Props.get("model_test").toBoolean) {
      val testData = trainingData.sample(withReplacement = false, Props.get("model_sampleRate").toDouble)
      // Count (misclassified, total) in a single Spark job instead of two separate count() jobs.
      val (errors, total) = testData
        .map(point => (if (model.predict(point.features) != point.label) 1L else 0L, 1L))
        .fold((0L, 0L)) { case ((e1, n1), (e2, n2)) => (e1 + e2, n1 + n2) }
      if (total == 0L) {
        // Guard against 0/0 (NaN) when the sample happens to be empty.
        logInfo("Test sample is empty; error rate not computed")
      } else {
        val testError = errors.toDouble / total
        logInfo(s"Test error rate: $testError")
      }
    } else {
      logInfo("Model testing disabled (model_test=false); skipping error-rate evaluation")
    }
  }

  /**
   * Loads the model from `modelPath`, classifies every vector, and writes one
   * `"<category name>\t<original line>"` row per input, newline-separated, to `resultPath`.
   *
   * @param lineAndVectors pairs of (original text line, feature vector)
   * @param resultPath     destination passed to `FileUtils.saveFileToHDFS`
   */
  def predictAndSave(lineAndVectors: Array[(String, org.apache.spark.mllib.linalg.Vector)], resultPath: String) = {
    val model = RandomForestModel.load(sc, modelPath)
    val result = lineAndVectors.map { case (line, vector) =>
      // Predicted labels are category codes; map them back to names.
      val clsCode = model.predict(vector).toInt.toString
      s"${clsCodeToNameMap(clsCode)}\t$line"
    }.mkString("\n")
    FileUtils.saveFileToHDFS(resultPath, result)
  }
}
开发者ID:luciuschina,项目名称:ArticleCategories,代码行数:56,代码来源:MLLibRandomForestModel.scala
示例2: RandomForestAlgorithmParams
//设置package包名称以及导入依赖的类
package org.template.classification
import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext
import grizzled.slf4j.Logger
/**
 * Hyper-parameters that `RandomForestAlgorithm.train` forwards, in this order,
 * to MLlib's `RandomForest.trainClassifier`.
 *
 * @param numClasses            number of label classes the classifier distinguishes
 * @param numTrees              number of trees in the forest
 * @param featureSubsetStrategy per-split feature subset strategy (e.g. "auto")
 * @param impurity              split impurity measure (e.g. "gini")
 * @param maxDepth              maximum depth of each tree
 * @param maxBins               maximum number of bins used when splitting features
 */
case class RandomForestAlgorithmParams(
  val numClasses: Int,
  val numTrees: Int,
  val featureSubsetStrategy: String,
  val impurity: String,
  val maxDepth: Int,
  val maxBins: Int) extends Params
// Extends P2LAlgorithm because MLlib's RandomForestModel doesn't contain an RDD.
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams)
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  /**
   * Fits a random-forest classifier on the prepared labeled points, using the
   * hyper-parameters held in `ap`.
   */
  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // No entries means every feature is treated as continuous.
    val allFeaturesContinuous = Map.empty[Int, Int]
    val points = data.labeledPoints
    RandomForest.trainClassifier(points, ap.numClasses, allFeaturesContinuous,
      ap.numTrees, ap.featureSubsetStrategy, ap.impurity, ap.maxDepth, ap.maxBins)
  }

  /**
   * Classifies a single query from its voice, data and text usage (in that feature order).
   */
  def predict(model: RandomForestModel, query: Query): PredictedResult = {
    val usage = Array(query.voice_usage, query.data_usage, query.text_usage)
    val predictedLabel = model.predict(Vectors.dense(usage))
    new PredictedResult(predictedLabel)
  }
}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:50,代码来源:RandomForestAlgorithm.scala
示例3: RandomForestAlgorithmTest
//设置package包名称以及导入依赖的类
package org.template.classification
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.scalatest.FlatSpec
import org.scalatest.Matchers
class RandomForestAlgorithmTest
  extends FlatSpec with SharedSingletonContext with Matchers {

  val params = RandomForestAlgorithmParams(
    numClasses = 7,
    numTrees = 12,
    featureSubsetStrategy = "auto",
    impurity = "gini",
    maxDepth = 4,
    maxBins = 100)

  val algorithm = new RandomForestAlgorithm(params)

  // One point per label; the dominant feature differs for each label.
  val dataSource = Seq(
    LabeledPoint(0, Vectors.dense(1000, 10, 10)),
    LabeledPoint(1, Vectors.dense(10, 1000, 10)),
    LabeledPoint(2, Vectors.dense(10, 10, 1000))
  )

  /** Trains a model on `dataSource` using the shared Spark context. */
  private def trainModel(): RandomForestModel = {
    val dataSourceRDD = sparkContext.parallelize(dataSource)
    val preparedData = new PreparedData(labeledPoints = dataSourceRDD)
    algorithm.train(sparkContext, preparedData)
  }

  "train" should "return RandomForest model" in {
    trainModel() shouldBe a [RandomForestModel]
  }

  // The type check above is guaranteed statically by train's signature; this one actually
  // verifies that the configured hyper-parameters reach MLlib.
  it should "build the configured number of trees" in {
    trainModel().numTrees shouldBe params.numTrees
  }
}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:36,代码来源:RandomForestAlgorithmTest.scala
示例4: Predictor
//设置package包名称以及导入依赖的类
package core
import kafka.utils.Json
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.streaming.StreamingContext
import scala.collection.parallel.mutable.ParSeq
object Predictor {

  /** Named forests scored by `getPredictions`; `null` until `setUpModels` is called. */
  var models: ParSeq[(String, RandomForestModel)] = null

  /**
   * Registers the forests that `getPredictions` will score against.
   *
   * @param ssc    not used by this method; kept for API compatibility
   * @param models (model name, model) pairs
   */
  def setUpModels(ssc: StreamingContext, models: ParSeq[(String, RandomForestModel)]): Unit = {
    this.models = models
  }

  /**
   * Scores `v` with every registered forest.
   *
   * @return one tuple per model: (model name, forest prediction, JSON with name and prediction,
   *         JSON that additionally includes tree count, total node count and per-tree details)
   * @throws IllegalArgumentException if `setUpModels` has not been called yet
   */
  def getPredictions(v: Vector) = {
    require(models != null, "Predictor.setUpModels must be called before getPredictions")
    models.map { case (name, forest) =>
      val pred = forest.predict(v)
      // `models` is already a parallel collection; a nested `.par` over the trees would only add
      // scheduling overhead, so per-tree details are built sequentially (order is unchanged).
      val perTree = forest.trees.map(tree =>
        Map("nodes" -> tree.numNodes, "prediction" -> tree.predict(v)))
      (
        name,
        pred,
        Json.encode(Map("modelName" -> name, "prediction" -> pred)),
        Json.encode(
          Map(
            "modelName" -> name,
            "numTrees" -> forest.numTrees,
            "totalNodes" -> forest.totalNumNodes,
            "prediction" -> pred,
            "trees" -> perTree
          )
        )
      )
    }
  }
}
开发者ID:jandion,项目名称:SparkOFP,代码行数:41,代码来源:Predictor.scala
注:本文中的org.apache.spark.mllib.tree.model.RandomForestModel类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论