• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Scala RandomForestModel类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Scala中org.apache.spark.mllib.tree.model.RandomForestModel类的典型用法代码示例。如果您想了解Scala RandomForestModel类的具体用法、使用方法或使用示例，这里精选的代码示例或许能为您提供帮助。



在下文中一共展示了RandomForestModel类的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Scala代码示例。

示例1: MLLibRandomForestModel

//设置package包名称以及导入依赖的类
package com.asto.dmp.articlecate.biz

import com.asto.dmp.articlecate.base.Props
import com.asto.dmp.articlecate.utils.FileUtils
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import com.asto.dmp.articlecate.biz.ClsFeaturesParser._
import scala.collection._

/**
 * Trains, saves and applies an MLlib random-forest classifier.
 *
 * @param sc        Spark context used to load the training data and to save/load the model
 * @param modelPath location the trained model is saved to (after clearing it with
 *                  `FileUtils.deleteFilesInHDFS`) and loaded from
 */
class MLLibRandomForestModel(val sc: SparkContext, val modelPath: String) extends scala.Serializable with Logging {

  /**
   * Trains a random-forest classifier on a LibSVM file, replaces any model stored at
   * `modelPath` with it, and then runs the optional sampled error check.
   *
   * Hyper-parameters are read from `Props`: `model_numTrees`, `model_featureSubsetStrategy`,
   * `model_impurity`, `model_maxDepth` and `model_maxBins`. The number of classes is the size of
   * `ClsFeaturesParser.clsNameToCodeMap`.
   *
   * @param svmTrainDataPath path of the training data in LibSVM format
   */
  def genRandomForestModel(svmTrainDataPath: String) = {
    val numClasses = ClsFeaturesParser.clsNameToCodeMap.size
    // An empty map tells MLlib that every feature is continuous.
    val categoricalFeaturesInfo = immutable.Map[Int, Int]()
    val numTrees = Props.get("model_numTrees").toInt
    val featureSubsetStrategy = Props.get("model_featureSubsetStrategy")
    val impurity = Props.get("model_impurity")
    val maxDepth = Props.get("model_maxDepth").toInt
    val maxBins = Props.get("model_maxBins").toInt

    // Cached because it is reused by training and by the optional error check below.
    val trainingData = MLUtils.loadLibSVMFile(sc, svmTrainDataPath).cache()
    try {
      val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      // Remove any previous model at modelPath before saving the new one.
      FileUtils.deleteFilesInHDFS(modelPath)
      model.save(sc, modelPath)

      testErrorRate(trainingData, model)
    } finally {
      // Release the cached partitions; otherwise they stay pinned for the lifetime of `sc`.
      trainingData.unpersist(blocking = false)
    }
  }

  /**
   * When `Props("model_test")` is true, logs the misclassification rate of `model` on a random
   * sample (fraction `Props("model_sampleRate")`, without replacement) of `trainingData`.
   *
   * NOTE: the sample is drawn from the training data itself, so this is a training-set error rate,
   * not a held-out generalization estimate.
   */
  private def testErrorRate(trainingData: RDD[LabeledPoint], model: RandomForestModel) = {
    if (Props.get("model_test").toBoolean) {
      // Cached so the two count() jobs below do not recompute the sample.
      val sampled = trainingData.sample(withReplacement = false, Props.get("model_sampleRate").toDouble).cache()
      try {
        val total = sampled.count()
        if (total == 0) {
          // Avoid reporting NaN (0.0 / 0) when the sample happens to be empty.
          logWarning("Model test skipped: the sampled test data set is empty")
        } else {
          val misclassified = sampled.filter(point => model.predict(point.features) != point.label).count()
          logInfo(s"Model error rate on sampled training data: ${misclassified.toDouble / total}")
        }
      } finally {
        sampled.unpersist(blocking = false)
      }
    } else {
      logInfo("Model test is disabled (model_test = false)")
    }
  }

  /**
   * Loads the model saved at `modelPath`, classifies every vector, and writes one line per input,
   * formatted as `name\tline`, to `resultPath`. Here `name` is the entry of `clsCodeToNameMap`
   * for the predicted class code.
   *
   * NOTE(review): `clsCodeToNameMap(...)` presumably throws if a predicted code has no name entry;
   * confirm that the map covers every class the model can output.
   *
   * @param lineAndVectors pairs of (original text line, its feature vector)
   * @param resultPath     destination file for the tab-separated results
   */
  def predictAndSave(lineAndVectors: Array[(String, org.apache.spark.mllib.linalg.Vector)], resultPath: String) = {
    val model = RandomForestModel.load(sc, modelPath)
    val result = lineAndVectors.map { case (line, features) =>
      s"${clsCodeToNameMap(model.predict(features).toInt.toString)}\t$line"
    }.mkString("\n")
    FileUtils.saveFileToHDFS(resultPath, result)
  }
}
开发者ID:luciuschina,项目名称:ArticleCategories,代码行数:56,代码来源:MLLibRandomForestModel.scala


示例2: RandomForestAlgorithmParams

//设置package包名称以及导入依赖的类
package org.template.classification

import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext

import grizzled.slf4j.Logger

/**
 * Hyper-parameters for [[RandomForestAlgorithm]]; each field is passed positionally to
 * MLlib's `RandomForest.trainClassifier`.
 *
 * @param numClasses            number of label classes
 * @param numTrees              number of trees in the forest
 * @param featureSubsetStrategy feature-subset strategy name, forwarded unchanged to MLlib
 * @param impurity              impurity criterion name, forwarded unchanged to MLlib
 * @param maxDepth              maximum depth of each tree
 * @param maxBins               maximum number of bins used when splitting features
 */
case class RandomForestAlgorithmParams(numClasses: Int,
                                       numTrees: Int,
                                       featureSubsetStrategy: String,
                                       impurity: String,
                                       maxDepth: Int,
                                       maxBins: Int)
  extends Params

// Extends P2LAlgorithm because the trained MLlib RandomForestModel does not contain an RDD.
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams)
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  /** Trains a random-forest classifier on the prepared labeled points using the configured params. */
  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // No categorical features are declared, so MLlib treats every feature as continuous.
    val allContinuous = Map.empty[Int, Int]
    RandomForest.trainClassifier(data.labeledPoints, ap.numClasses, allContinuous, ap.numTrees,
      ap.featureSubsetStrategy, ap.impurity, ap.maxDepth, ap.maxBins)
  }

  /** Builds a (voice, data, text) usage vector from the query and returns the model's predicted label. */
  def predict(model: RandomForestModel, query: Query): PredictedResult = {
    val usage = Array(query.voice_usage, query.data_usage, query.text_usage)
    val predictedLabel = model.predict(Vectors.dense(usage))
    new PredictedResult(predictedLabel)
  }

}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:50,代码来源:RandomForestAlgorithm.scala


示例3: RandomForestAlgorithmTest

//设置package包名称以及导入依赖的类
package org.template.classification

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel

import org.scalatest.FlatSpec
import org.scalatest.Matchers

/** Checks that [[RandomForestAlgorithm.train]] builds a forest honoring the configured parameters. */
class RandomForestAlgorithmTest
  extends FlatSpec with SharedSingletonContext with Matchers {

  val params = RandomForestAlgorithmParams(
    numClasses = 7,
    numTrees = 12,
    featureSubsetStrategy = "auto",
    impurity = "gini",
    maxDepth = 4,
    maxBins = 100)
  val algorithm = new RandomForestAlgorithm(params)

  // One point per label; each label has a different dominant feature.
  val dataSource = Seq(
    LabeledPoint(0, Vectors.dense(1000, 10, 10)),
    LabeledPoint(1, Vectors.dense(10, 1000, 10)),
    LabeledPoint(2, Vectors.dense(10, 10, 1000))
  )

  "train" should "return RandomForest model" in {
    val dataSourceRDD = sparkContext.parallelize(dataSource)
    val preparedData = new PreparedData(labeledPoints = dataSourceRDD)
    val model = algorithm.train(sparkContext, preparedData)
    model shouldBe a [RandomForestModel]
    // The static type is already guaranteed by the compiler; also verify the params were applied.
    model.numTrees shouldBe params.numTrees
    model.trees.map(_.depth).max should be <= params.maxDepth
  }
}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:36,代码来源:RandomForestAlgorithmTest.scala


示例4: Predictor

//设置package包名称以及导入依赖的类
package core

import kafka.utils.Json
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.streaming.StreamingContext

import scala.collection.parallel.mutable.ParSeq

/** Holds a set of named random-forest models and evaluates all of them against a feature vector. */
object Predictor {

  /** Named models to evaluate; `null` until [[setUpModels]] is called. */
  var models: ParSeq[(String, RandomForestModel)] = null

  /**
   * Installs the models used by [[getPredictions]].
   *
   * @param ssc    not used by this method; kept for API compatibility
   * @param models (name, model) pairs to evaluate
   */
  def setUpModels(ssc: StreamingContext, models: ParSeq[(String, RandomForestModel)]) = {
    this.models = models
  }

  /**
   * Runs every installed model on `v`.
   *
   * @return for each model: (name, prediction, compact JSON with name and prediction,
   *         detailed JSON that also includes tree counts and per-tree predictions)
   * @throws IllegalStateException if [[setUpModels]] has not been called yet
   */
  def getPredictions(v: Vector) = {
    val installed = models
    if (installed == null) {
      throw new IllegalStateException("Predictor.setUpModels must be called before getPredictions")
    }
    installed.map { case (name, model) =>
      val pred = model.predict(v)
      val summary = Json.encode(Map("modelName" -> name, "prediction" -> pred))
      // The outer collection is already parallel, so trees are evaluated sequentially here.
      val perTree = model.trees.map(tree => Map("nodes" -> tree.numNodes, "prediction" -> tree.predict(v)))
      val detail = Json.encode(
        Map(
          "modelName" -> name,
          "numTrees" -> model.numTrees,
          "totalNodes" -> model.totalNumNodes,
          "prediction" -> pred,
          "trees" -> perTree
        )
      )
      (name, pred, summary, detail)
    }
  }
}
开发者ID:jandion,项目名称:SparkOFP,代码行数:41,代码来源:Predictor.scala



注:本文中的org.apache.spark.mllib.tree.model.RandomForestModel类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Scala JsonFactory类代码示例发布时间:2022-05-23
下一篇:
Scala ReentrantReadWriteLock类代码示例发布时间:2022-05-23
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap