• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Scala SVMWithSGD类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Scala中org.apache.spark.mllib.classification.SVMWithSGD的典型用法代码示例。如果您正苦于以下问题：Scala SVMWithSGD类的具体用法？Scala SVMWithSGD怎么用？Scala SVMWithSGD使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。



在下文中一共展示了SVMWithSGD类的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Scala代码示例。

示例1: SVMPipeline

//设置package包名称以及导入依赖的类
package org.stumbleuponclassifier

import org.apache.log4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint


/**
 * Trains a linear SVM with SGD on a tab-separated training file and prints the
 * model's accuracy measured on that same training data.
 */
object SVMPipeline {
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  /**
   * Loads the training file, fits an SVM and prints the training-set accuracy.
   *
   * Each input line is split on tabs and stripped of double quotes. The last
   * column is parsed as the integer label; columns 4 up to (but excluding) the
   * last column are the features, with "?" treated as 0.0.
   *
   * @param sc Spark context used to read the training file
   */
  def svmPipeline(sc: SparkContext) = {
    val records = sc.textFile("/home/ubuntu/work/ml-resources/spark-ml/train_noheader.tsv").map(line => line.split("\t"))

    // Cached because the data is scanned once per SGD iteration and then twice
    // more for the accuracy computation below.
    val data = records.map { r =>
      val trimmed = r.map(_.replaceAll("\"", ""))
      val label = trimmed(r.size - 1).toInt
      val features = trimmed.slice(4, r.size - 1).map(d => if (d == "?") 0.0 else d.toDouble)
      LabeledPoint(label, Vectors.dense(features))
    }.cache()

    // params for SVM
    val numIterations = 10

    // Run training algorithm to build the model
    val svmModel = SVMWithSGD.train(data, numIterations)

    // The model's threshold is intentionally left in place: with a threshold set,
    // predict() yields a class (0.0 / 1.0) that can be compared with the label.
    // Clearing it would make predict() return the raw margin, which practically
    // never equals a label and would drive the reported accuracy towards zero.
    val svmTotalCorrect = data.map { point =>
      if (svmModel.predict(point.features) == point.label) 1 else 0
    }.sum()

    // calculate accuracy
    val svmAccuracy = svmTotalCorrect / data.count()
    println(svmAccuracy)
  }

}
开发者ID:PacktPublishing,项目名称:Machine-Learning-with-Spark-Second-Edition,代码行数:42,代码来源:SVMPipeline.scala


示例2: MllibSGD

//设置package包名称以及导入依赖的类
package optimizers

import breeze.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater, Updater}
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.rdd.RDD
import utils.Functions._

import scala.tools.cmd.gen.AnyVals.D




/**
 * Optimizer that delegates to one of Spark MLlib's SGD-based linear algorithms.
 *
 * @param data        labeled training points passed to the MLlib algorithm
 * @param loss        loss function forwarded to the `Optimizer` base class
 * @param regularizer selects the MLlib updater and supplies the regularization
 *                    parameter (`lambda`)
 * @param params      SGD settings: iterations, mini-batch fraction and step size
 * @param ctype       algorithm selector: "SVM", "LR" or "Regression"; any other
 *                    value fails construction with a `MatchError`
 */
class MllibSGD(val data: RDD[LabeledPoint],
               loss: LossFunction,
               regularizer: Regularizer,
               params: SGDParameters,
               ctype: String
              ) extends Optimizer(loss, regularizer) {
  // MLlib algorithm instance chosen by `ctype`.
  val opt = ctype match {
    case "SVM" => new SVMWithSGD()
    case "LR" => new LogisticRegressionWithSGD()
    case "Regression" => new LinearRegressionWithSGD()
  }

  // MLlib updater corresponding to the requested regularizer.
  val reg: Updater = (regularizer: Regularizer) match {
    case _: L1Regularizer => new L1Updater
    case _: L2Regularizer => new SquaredL2Updater
    case _: Unregularized => new SimpleUpdater
  }

  // Each algorithm exposes its own gradient-descent optimizer; a type pattern
  // picks it without casts and the shared settings are applied in one place.
  opt match {
    case svm: SVMWithSGD => configure(svm.optimizer)
    case lr: LogisticRegressionWithSGD => configure(lr.optimizer)
    case linear: LinearRegressionWithSGD => configure(linear.optimizer)
  }

  /** Applies the SGD parameters, regularization strength and updater to `gd`. */
  private def configure(gd: org.apache.spark.mllib.optimization.GradientDescent): Unit = {
    gd.setNumIterations(params.iterations)
      .setMiniBatchFraction(params.miniBatchFraction)
      .setStepSize(params.stepSize)
      .setRegParam(regularizer.lambda)
      .setUpdater(reg)
  }

  /**
   * Runs the configured MLlib algorithm on `data`.
   *
   * @return the trained model's weights as a Breeze dense vector
   */
  override def optimize(): Vector[Double] = {
    val model = opt.run(data)
    val w = model.weights.toArray
    DenseVector(w)
  }
}
开发者ID:mlbench,项目名称:mlbench,代码行数:60,代码来源:MllibSGD.scala


示例3: SVMPipeline

//设置package包名称以及导入依赖的类
package org.sparksamples.classification.stumbleupon

import org.apache.log4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint


/**
 * Trains a linear SVM with SGD on a tab-separated training file and prints the
 * model's accuracy measured on that same training data.
 */
object SVMPipeline {
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  /**
   * Loads the training file, fits an SVM and prints the training-set accuracy.
   *
   * Each input line is split on tabs and stripped of double quotes. The last
   * column is parsed as the integer label; columns 4 up to (but excluding) the
   * last column are the features, with "?" treated as 0.0.
   *
   * @param sc Spark context used to read the training file
   */
  def svmPipeline(sc: SparkContext) = {
    val records = sc.textFile("/home/ubuntu/work/ml-resources/spark-ml/train_noheader.tsv").map(line => line.split("\t"))

    // Cached because the data is scanned once per SGD iteration and then twice
    // more for the accuracy computation below.
    val data = records.map { r =>
      val trimmed = r.map(_.replaceAll("\"", ""))
      val label = trimmed(r.size - 1).toInt
      val features = trimmed.slice(4, r.size - 1).map(d => if (d == "?") 0.0 else d.toDouble)
      LabeledPoint(label, Vectors.dense(features))
    }.cache()

    // params for SVM
    val numIterations = 10

    // Run training algorithm to build the model
    val svmModel = SVMWithSGD.train(data, numIterations)

    // The model's threshold is intentionally left in place: with a threshold set,
    // predict() yields a class (0.0 / 1.0) that can be compared with the label.
    // Clearing it would make predict() return the raw margin, which practically
    // never equals a label and would drive the reported accuracy towards zero.
    val svmTotalCorrect = data.map { point =>
      if (svmModel.predict(point.features) == point.label) 1 else 0
    }.sum()

    // calculate accuracy
    val svmAccuracy = svmTotalCorrect / data.count()
    println(svmAccuracy)
  }

}
开发者ID:PacktPublishing,项目名称:Machine-Learning-with-Spark-Second-Edition,代码行数:42,代码来源:SVMPipeline.scala


示例4: SVMTest

//设置package包名称以及导入依赖的类
package cn.edu.bjtu


import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

/**
 * Trains an SVM on a LibSVM-format file and prints area under ROC, sensitivity,
 * specificity and accuracy measured on a held-out split.
 */
object SVMTest {
  /**
   * Entry point: builds the Spark session, trains the model and prints metrics.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf()
      .setAppName("SVMTest")
      .setMaster("spark://master:7077")
      .setJars(Array("/home/hadoop/SVM.jar"))

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()

    try {
      spark.sparkContext.setLogLevel("WARN")

      val data = MLUtils.loadLibSVMFile(spark.sparkContext, "hdfs://master:9000/sample_formatted.txt")
      // Split data into training (70%) and test (30%).
      val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)

      // Run training algorithm to build the model
      val numIterations = 100
      val model = SVMWithSGD.train(training, numIterations)

      // Clear the threshold so predict() returns the raw margin. The ROC curve
      // must be built from continuous scores; feeding it 0/1 predictions
      // collapses the curve to a single operating point.
      model.clearThreshold()

      // Margin above which a point is classified as positive (strictly greater,
      // matching how a model threshold is applied).
      val decisionThreshold = -5000.0

      // Raw scores on the test set; cached because two computations consume them.
      val rawScoreAndLabels = test.map { point =>
        (model.predict(point.features), point.label)
      }.cache()

      // Get evaluation metrics.
      val metrics = new BinaryClassificationMetrics(rawScoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println(s"Area under ROC = $auROC")

      // Hard 0/1 predictions at the chosen threshold; cached because every
      // count below would otherwise re-run prediction over the test set.
      val scoreAndLabels = rawScoreAndLabels.map { case (score, label) =>
        (if (score > decisionThreshold) 1.0 else 0.0, label)
      }.cache()

      val truePositives = scoreAndLabels.filter(x => x._1 == x._2 && x._1 == 1.0).count().toDouble
      val actualPositives = scoreAndLabels.filter(x => x._2 == 1.0).count().toDouble
      val trueNegatives = scoreAndLabels.filter(x => x._1 == x._2 && x._1 == 0.0).count().toDouble
      val actualNegatives = scoreAndLabels.filter(x => x._2 == 0.0).count().toDouble
      val correct = scoreAndLabels.filter(x => x._1 == x._2).count().toDouble
      val total = scoreAndLabels.count().toDouble

      println(s"Sensitivity = ${truePositives / actualPositives}")
      println(s"Specificity = ${trueNegatives / actualNegatives}")
      println(s"Accuracy = ${correct / total}")
    } finally {
      // Release cluster resources even when loading or training fails.
      spark.stop()
    }
  }
}
开发者ID:XiaoyuGuo,项目名称:DataFusionClass,代码行数:52,代码来源:SVMTest.scala


示例5: SVMOneVersusAllModel

//设置package包名称以及导入依赖的类
package com.scalafi.dynamics.svm

import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.slf4j.LoggerFactory

/**
 * A set of binary SVM models, each stored under a `Double` key.
 *
 * @param models the binary models, indexed by key
 */
class SVMOneVersusAllModel(val models: Map[Double, SVMModel]) {
  /**
   * Scores `testData` with every model.
   *
   * @param testData feature vector to score
   * @return each model's prediction for `testData`, under that model's key
   */
  def predict(testData: Vector): Map[Double, Double] = {
    // A strict map is built on purpose: `mapValues` yields a lazy view that
    // re-runs predict() on every lookup and is not a plain serializable Map.
    models.map { case (label, model) => label -> model.predict(testData) }
  }
}

/**
 * Trains one binary SVM per distinct label found in the input, each separating
 * that label from all the others.
 *
 * @param svm the configured SVM algorithm used to train every binary model
 */
class SVMOneVersusAll(svm: SVMWithSGD) {
  private val log = LoggerFactory.getLogger(classOf[SVMOneVersusAll])

  /** Rewrites labels so that points carrying `label` become 1.0 and all others 0.0. */
  private def relabel(input: RDD[LabeledPoint], label: Double): RDD[LabeledPoint] = {
    input.map(point => if (point.label == label) point.copy(label = 1.0) else point.copy(label = 0.0))
  }

  /**
   * Trains the per-label models.
   *
   * Fails the `assume` check when the input has two or fewer distinct labels.
   *
   * @param input labeled training points
   * @return a model holding one threshold-cleared SVM per distinct label
   */
  def run(input: RDD[LabeledPoint]): SVMOneVersusAllModel = {
    // De-duplicate on the cluster so only the distinct labels reach the driver,
    // rather than one label per input row.
    val labels = input.map(_.label).distinct().collect().toSet

    log.info(s"Input labels number: ${labels.size}. Labels: [${labels.mkString(", ")}]")
    assume(labels.size > 2, "Labels size should be greater than 2")

    val models = labels.map { label =>
      log.debug(s"Train model: '$label' versus [${labels.filterNot(_ == label).mkString(", ")}]")
      val model = svm.run(relabel(input, label))
      // Clear the threshold so each model's predict() returns its raw score.
      model.clearThreshold()
      (label, model)
    }

    new SVMOneVersusAllModel(models.toMap)
  }
}

/** Factory for training a one-versus-all SVM model. */
object SVMOneVersusAll {
  /**
   * Trains a one-versus-all model using an SVM with the given iteration count.
   *
   * @param input         labeled training points
   * @param numIterations number of SGD iterations set on the SVM's optimizer
   * @return the trained one-versus-all model
   */
  def train(input: RDD[LabeledPoint], numIterations: Int): SVMOneVersusAllModel = {
    val algorithm = new SVMWithSGD()
    algorithm.optimizer.setNumIterations(numIterations)

    val trainer = new SVMOneVersusAll(algorithm)
    trainer.run(input)
  }
}
开发者ID:wensdong,项目名称:ML-test,代码行数:51,代码来源:SVMOneVersusAll.scala



注:本文中的org.apache.spark.mllib.classification.SVMWithSGD类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Scala TopicAndPartition类代码示例发布时间:2022-05-23
下一篇:
Scala Channels类代码示例发布时间:2022-05-23
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap