This article collects typical usage examples of the Java class org.apache.spark.mllib.classification.SVMWithSGD. If you are wondering what the SVMWithSGD class is for, how to use it, or what real-world code that uses it looks like, the curated examples below should help.
The SVMWithSGD class belongs to the org.apache.spark.mllib.classification package. Seven code examples of the class are shown below, sorted by popularity by default.
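Before the examples, here is a minimal, self-contained sketch of the typical SVMWithSGD workflow (load LIBSVM data, train, predict). It is not taken from any of the examples below; the class name SVMWithSGDQuickStart and the input path "data/sample_libsvm_data.txt" are placeholders you would replace with your own.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class SVMWithSGDQuickStart {
    public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext(
                new SparkConf().setAppName("SVMWithSGDQuickStart").setMaster("local"));
        // Load LIBSVM-formatted data as an RDD of LabeledPoint (label + feature vector).
        JavaRDD<LabeledPoint> data =
                MLUtils.loadLibSVMFile(jsc.sc(), "data/sample_libsvm_data.txt").toJavaRDD();
        // Train a linear SVM with stochastic gradient descent for a fixed number of iterations.
        int numIterations = 100;
        SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);
        // With the default threshold, predict() returns the class label 0.0 or 1.0.
        System.out.println("First prediction: " + model.predict(data.first().features()));
        jsc.stop();
    }
}

The static SVMWithSGD.train(...) helper used here is the same entry point most of the examples below rely on.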
Example 1: trainWithSGD
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations) {
    // Train the model
    if (modelName.equals("SVMModel")) {
        SVMModel svmmodel = SVMWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T) (Object) svmmodel;
    } else if (modelName.equals("LogisticRegressionModel")) {
        LogisticRegressionModel lrmodel = LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T) (Object) lrmodel;
    }
    // Evaluate the trained model
    EvaluateProcess<T> evalProcess = new EvaluateProcess<T>(model, modelName, validData, numClasses);
    evalProcess.evalute(numClasses);
    return model;
}
Author: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 18, Source: TrainModel.java
Example 2: generateKMeansModel
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
// Note: despite its name, this method trains an SVM model, not a K-means model.
public SVMModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                    SVMDetectionAlgorithm svmDetectionAlgorithm,
                                    SVMModelSummary SVMModelSummary) {
    SVMModel svmModel;
    // Pick the SVMWithSGD.train(...) overload that matches the configured parameters;
    // a value of -1 indicates the parameter was not set.
    if (svmDetectionAlgorithm.getMiniBatchFraction() != -1) {
        svmModel = SVMWithSGD.train(parsedData.rdd(),
                svmDetectionAlgorithm.getNumIterations(),
                svmDetectionAlgorithm.getStepSize(),
                svmDetectionAlgorithm.getRegParam(),
                svmDetectionAlgorithm.getMiniBatchFraction());
    } else if (svmDetectionAlgorithm.getRegParam() != -1) {
        svmModel = SVMWithSGD.train(parsedData.rdd(),
                svmDetectionAlgorithm.getNumIterations(),
                svmDetectionAlgorithm.getStepSize(),
                svmDetectionAlgorithm.getRegParam());
    } else {
        svmModel = SVMWithSGD.train(parsedData.rdd(),
                svmDetectionAlgorithm.getNumIterations());
    }
    SVMModelSummary.setSVMDetectionAlgorithm(svmDetectionAlgorithm);
    return svmModel;
}
Author: shlee89, Project: athena, Lines: 27, Source: SVMDistJob.java
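Example 2 chooses between the static SVMWithSGD.train(...) overloads based on which parameters are set. An equivalent, more explicit formulation goes through an SVMWithSGD instance and its optimizer(), as in the sketch below; the class name SVMTrainerSketch and all parameter values are illustrative, not taken from SVMDistJob.

import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

public class SVMTrainerSketch {
    // Mirrors SVMWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction)
    // but configures the optimizer explicitly; all values here are illustrative.
    public static SVMModel train(RDD<LabeledPoint> trainingData) {
        SVMWithSGD svm = new SVMWithSGD();
        svm.optimizer()
           .setNumIterations(100)      // number of SGD iterations
           .setStepSize(1.0)           // initial SGD step size
           .setRegParam(0.01)          // regularization strength (L2 by default)
           .setMiniBatchFraction(1.0); // fraction of the data sampled per iteration
        return svm.run(trainingData);
    }
}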
Example 3: ModelSVM
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
public ModelSVM(JavaRDD<LabeledPoint> training) {
    super();
    SVMWithSGD svmAlg = new SVMWithSGD();
    svmAlg.optimizer().setNumIterations(100).setRegParam(0.1).setUpdater(new L1Updater());
    model = svmAlg.run(training.rdd());
    // Clear the default threshold.
    // model.clearThreshold();
    // model.setThreshold(0.001338428);
}
Author: mhardalov, Project: news-credibility, Lines: 12, Source: ModelSVM.java
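The commented-out clearThreshold()/setThreshold(...) calls in Example 3 control whether predict() returns a hard label or the raw SVM margin. The sketch below shows the difference; the class name ThresholdSketch is made up for illustration.

import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.linalg.Vector;

public class ThresholdSketch {
    // With the threshold cleared, predict() returns the raw SVM margin (useful for ranking or AUC);
    // once a threshold is set again, predict() returns the hard 0.0/1.0 label.
    public static void demo(SVMModel model, Vector features) {
        model.clearThreshold();
        System.out.println("raw margin: " + model.predict(features));
        model.setThreshold(0.0);
        System.out.println("hard label: " + model.predict(features));
    }
}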
Example 4: train
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
/**
 * This method uses the stochastic gradient descent (SGD) algorithm to train a support vector machine (SVM) model.
 *
 * @param trainingDataset         Training dataset as a JavaRDD of LabeledPoints
 * @param noOfIterations          Number of iterations
 * @param regularizationType      Regularization type: L1 or L2
 * @param regularizationParameter Regularization parameter
 * @param initialLearningRate     Initial learning rate (SGD step size)
 * @param miniBatchFraction       SGD minibatch fraction
 * @return SVM model
 */
public SVMModel train(JavaRDD<LabeledPoint> trainingDataset, int noOfIterations, String regularizationType,
        double regularizationParameter, double initialLearningRate, double miniBatchFraction) {
    SVMWithSGD svmWithSGD = new SVMWithSGD();
    if (regularizationType.equals(MLConstants.L1)) {
        svmWithSGD.optimizer().setUpdater(new L1Updater()).setRegParam(regularizationParameter);
    } else if (regularizationType.equals(MLConstants.L2)) {
        svmWithSGD.optimizer().setUpdater(new SquaredL2Updater()).setRegParam(regularizationParameter);
    }
    svmWithSGD.optimizer().setNumIterations(noOfIterations).setStepSize(initialLearningRate)
            .setMiniBatchFraction(miniBatchFraction);
    return svmWithSGD.run(trainingDataset.rdd());
}
Author: wso2-attic, Project: carbon-ml, Lines: 24, Source: SVM.java
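A hypothetical call to the train(...) method above might look like the sketch below. The wrapper class name SVM matches the source file in the footer, but its constructor and the concrete values behind MLConstants.L1/MLConstants.L2 are assumptions; the literal "L2" is only a stand-in.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.regression.LabeledPoint;

public class SVMTrainCaller {
    // Trains with L2 regularization; "L2" stands in for whatever MLConstants.L2 resolves to.
    public static SVMModel trainWithL2(SVM svmTrainer, JavaRDD<LabeledPoint> trainingDataset) {
        return svmTrainer.train(trainingDataset,
                100,    // noOfIterations
                "L2",   // regularizationType (assumed value of MLConstants.L2)
                0.01,   // regularizationParameter
                0.1,    // initialLearningRate (SGD step size)
                1.0);   // miniBatchFraction
    }
}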
Example 5: main
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
public static void main(String[] args) {
    MudrodEngine me = new MudrodEngine();
    JavaSparkContext jsc = me.startSparkDriver().sc;
    String path = SparkSVM.class.getClassLoader().getResource("inputDataForSVM_spark.txt").toString();
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();

    // Run the training algorithm to build the model.
    int numIterations = 100;
    final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);

    // Save the model.
    model.save(jsc.sc(), SparkSVM.class.getClassLoader().getResource("javaSVMWithSGDModel").toString());

    jsc.sc().stop();
}
Author: apache, Project: incubator-sdap-mudrod, Lines: 19, Source: SparkSVM.java
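Example 5 persists the trained model with model.save(...). Loading it back uses the static SVMModel.load(SparkContext, String), as in this sketch; the class name LoadSavedSVM is made up, and the path must match wherever save(...) wrote the model.

import org.apache.spark.SparkContext;
import org.apache.spark.mllib.classification.SVMModel;

public class LoadSavedSVM {
    // Reads a model previously written with SVMModel.save(sc, path).
    public static SVMModel load(SparkContext sc, String modelPath) {
        return SVMModel.load(sc, modelPath);
    }
}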
Example 6: main
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
public static void main(String[] args) throws IOException {
    JavaSparkContext sc = new JavaSparkContext("local", "WikipediaKMeans");
    JavaRDD<String> lines = sc.textFile("data/" + input_file);
    JavaRDD<LabeledPoint> points = lines.map(new ParsePoint());

    // Split the initial RDD into 70% training data and 30% testing data (13L is a random seed):
    JavaRDD<LabeledPoint>[] splits = points.randomSplit(new double[]{0.7, 0.3}, 13L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> testing = splits[1];
    training.cache(); // redundant (already cached above), but harmless

    // Build the model.
    int numIterations = 500;
    final SVMModel model = SVMWithSGD.train(JavaRDD.toRDD(training), numIterations);
    model.clearThreshold();

    // Evaluate the model on the testing examples and compute the test error.
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = testing.map(
        new Function<LabeledPoint, Tuple2<Double, Double>>() {
            public Tuple2<Double, Double> call(LabeledPoint point) {
                double prediction = model.predict(point.features());
                System.out.println(" ++ prediction: " + prediction + " original: "
                        + map_to_print_original_text.get(point.features().compressed().toString()));
                return new Tuple2<Double, Double>(prediction, point.label());
            }
        }
    );
    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
        new Function<Tuple2<Double, Double>, Object>() {
            public Object call(Tuple2<Double, Double> pair) {
                return Math.pow(pair._1() - pair._2(), 2.0);
            }
        }
    ).rdd()).mean();
    System.out.println("Test Data Mean Squared Error = " + MSE);

    sc.stop();
}
Author: mark-watson, Project: power-java, Lines: 42, Source: SvmTextClassifier.java
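Because Example 6 calls model.clearThreshold(), predict() returns raw margins rather than 0/1 labels, so the mean squared error it reports is not a conventional classification metric. A more standard evaluation for this setup is area under the ROC curve, sketched below; the class name AucSketch is made up, while the Spark API calls are standard MLlib.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

public class AucSketch {
    // Scores every test point with the (threshold-cleared) model and computes area under the ROC curve.
    public static double auc(final SVMModel model, JavaRDD<LabeledPoint> testing) {
        JavaRDD<Tuple2<Object, Object>> scoreAndLabels = testing.map(
            new org.apache.spark.api.java.function.Function<LabeledPoint, Tuple2<Object, Object>>() {
                public Tuple2<Object, Object> call(LabeledPoint p) {
                    return new Tuple2<Object, Object>(model.predict(p.features()), p.label());
                }
            });
        return new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)).areaUnderROC();
    }
}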
Example 7: trainInternal
import org.apache.spark.mllib.classification.SVMWithSGD; // import the required package/class
@Override
protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
        throws LensException {
    SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
    return new SVMClassificationModel(modelId, svmModel);
}
Author: apache, Project: lens, Lines: 7, Source: SVMAlgo.java
Note: the org.apache.spark.mllib.classification.SVMWithSGD examples in this article were collected from open-source projects hosted on platforms such as GitHub. The code snippets were contributed by their original authors, who retain copyright; consult the corresponding project's license before redistributing or reusing them, and do not repost without permission.