This article collects typical usage examples of the Java class org.apache.spark.mllib.classification.LogisticRegressionModel. If you are wondering what the LogisticRegressionModel class does, or how to use it, the curated class examples below may help.
The LogisticRegressionModel class belongs to the org.apache.spark.mllib.classification package. Twenty code examples of the class are shown below, ordered by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
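Before the examples, here is a minimal sketch of the typical MLlib workflow around this class: load LIBSVM-format data, train with L-BFGS, predict, and persist the model. The class name, the local[*] master, and both paths are placeholders for illustration.
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class LogisticRegressionQuickStart {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("lr-quickstart").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Load LIBSVM-format training data (placeholder path).
        JavaRDD<LabeledPoint> training =
                MLUtils.loadLibSVMFile(jsc.sc(), "data/sample.libsvm").toJavaRDD();

        // Train a binary logistic regression model with L-BFGS.
        LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
                .setNumClasses(2)
                .run(training.rdd());

        // Predict a single point and persist the model for later reuse.
        double prediction = model.predict(training.first().features());
        System.out.println("prediction = " + prediction);
        model.save(jsc.sc(), "target/lr-model");  // reload later with LogisticRegressionModel.load(...)

        jsc.stop();
    }
}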
Example 1: predictForMetrics
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public JavaRDD<Tuple2<Object, Object>> predictForMetrics(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses){
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = null;
    if(modelName.equals("LogisticRegressionModel")){
        LogisticRegressionModel lrmodel = (LogisticRegressionModel) model;
        if(numClasses==2){
            lrmodel.clearThreshold();
        }
        //Predict
        predictionAndLabels = PredictUnit.predictForMetrics_LogisticRegressionModel(lrmodel, data);
    }
    else if(modelName.equals("SVMModel")){
        SVMModel svmmodel = (SVMModel) model;
        if(numClasses==2){
            svmmodel.clearThreshold();
        }
        //Predict
        predictionAndLabels = PredictUnit.predictForMetrics_SVMModel(svmmodel, data);
    }
    else if(modelName.equals("NaiveBayesModel")){
        NaiveBayesModel bayesmodel = (NaiveBayesModel) model;
        //Predict
        predictionAndLabels = PredictUnit.predictForMetrics_NaiveBayesModel(bayesmodel, data);
    }
    return predictionAndLabels;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 26, Source: PredictUnit.java
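A note on the clearThreshold() calls above: with the threshold cleared, predict() returns raw scores (the class-1 probability for logistic regression, the margin for SVM) instead of 0/1 labels, which is what MLlib's ranking metrics expect. A minimal sketch of consuming the returned RDD for a binary problem; predictionAndLabels comes from this example, the extra import and everything else is illustrative:
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; // assumed additional import

// predictionAndLabels: JavaRDD<Tuple2<Object, Object>> of (raw score, true label) pairs
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
System.out.println("Area under ROC = " + metrics.areaUnderROC());
System.out.println("Area under PR  = " + metrics.areaUnderPR());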
Example 2: predictForOutput
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public JavaRDD<Tuple2<Object, Object>> predictForOutput(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses, double threshold){
    JavaRDD<Tuple2<Object, Object>> FeaturesAndPrediction = null;
    if(modelName.equals("LogisticRegressionModel")){
        LogisticRegressionModel lrmodel = (LogisticRegressionModel) model;
        if(numClasses==2){
            lrmodel.setThreshold(threshold);
        }
        //Predict
        FeaturesAndPrediction = PredictUnit.predictForOutput_LogisticRegressionModel(lrmodel, data);
    }
    else if(modelName.equals("SVMModel")){
        SVMModel svmmodel = (SVMModel) model;
        if(numClasses==2){
            svmmodel.setThreshold(threshold);
        }
        //Predict
        FeaturesAndPrediction = PredictUnit.predictForOutput_SVMModel(svmmodel, data);
    }
    else if(modelName.equals("NaiveBayesModel")){
        NaiveBayesModel bayesmodel = (NaiveBayesModel) model;
        //Predict
        FeaturesAndPrediction = PredictUnit.predictForOutput_NaiveBayesModel(bayesmodel, data);
    }
    return FeaturesAndPrediction;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 27, Source: PredictUnit.java
Example 3: getModelInfo
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());
    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);
    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);
    return logisticRegressionModelInfo;
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 21, Source: LogisticRegressionModelInfoAdapter.java
Example 4: predictForOutput_LogisticRegressionModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public static JavaRDD<Tuple2<Object, Object>> predictForOutput_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    JavaRDD<Tuple2<Object, Object>> FeaturesAndPrediction = data.map(
        new Function<LabeledPoint, Tuple2<Object, Object>>() {
            private static final long serialVersionUID = 1L;
            public Tuple2<Object, Object> call(LabeledPoint p) {
                Double prediction = model.predict(p.features());
                return new Tuple2<Object, Object>(p.features(), prediction);
            }
        }
    );
    return FeaturesAndPrediction;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 13, Source: PredictUnit.java
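The returned RDD of (features, prediction) pairs can be persisted or post-processed like any other JavaRDD. A short hedged sketch; the output path is a placeholder and the second step assumes Java 8 lambdas are available:
// Write each (features, prediction) tuple out as a text line (path is a placeholder).
FeaturesAndPrediction.saveAsTextFile("hdfs:///output/lr-predictions");

// Or keep only the predicted labels and count how often each class was predicted.
Map<Object, Long> classCounts = FeaturesAndPrediction
        .map(pair -> pair._2())
        .countByValue();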
Example 5: PredictWithModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@SuppressWarnings("unchecked")
public PredictWithModel(String modelName, String modelPath, String testFile, int numClasses, int minPartition, double threshold, SparkContext sc){
    this.numClasses = numClasses;
    this.threshold = threshold;
    if(modelName.equals("LogisticRegressionModel")){
        LogisticRegressionModel lrmodel = LogisticRegressionModel.load(sc, modelPath);
        this.model = (T)(Object) lrmodel;
    }
    else if(modelName.equals("SVMModel")){
        SVMModel svmmodel = SVMModel.load(sc, modelPath);
        this.model = (T)(Object) svmmodel;
    }
    else if(modelName.equals("NaiveBayesModel")){
        NaiveBayesModel bayesmodel = NaiveBayesModel.load(sc, modelPath);
        this.model = (T)(Object) bayesmodel;
    }
    //Load testing data
    LoadProcess loadProcess = new LoadProcess(sc, minPartition);
    testingData = loadProcess.load(testFile, "Vector");
    testingData.cache();
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 24, Source: PredictWithModel.java
Example 6: trainWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@SuppressWarnings("unchecked")
public T trainWithLBFGS(){
    //Train the model
    if(modelName.equals("LogisticRegressionModel")){
        LogisticRegressionModel lrmodel = new LogisticRegressionWithLBFGS()
            .setNumClasses(numClasses)
            .run(trainingData.rdd());
        System.out.println("\n--------------------------------------\n weights: " + lrmodel.weights());
        System.out.println("--------------------------------------\n");
        this.model = (T)(Object) lrmodel;
    }
    //Evaluate the trained model
    EvaluateProcess<T> evalProcess = new EvaluateProcess<T>(model, modelName, validData, numClasses);
    evalProcess.evalute(numClasses);
    return model;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 21, Source: TrainModel.java
Example 7: trainWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations){
    //Train the model
    if(modelName.equals("SVMModel")){
        SVMModel svmmodel = SVMWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T)(Object) svmmodel;
    }
    else if(modelName.equals("LogisticRegressionModel")){
        LogisticRegressionModel lrmodel = LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T)(Object) lrmodel;
    }
    //Evaluate the trained model
    EvaluateProcess<T> evalProcess = new EvaluateProcess<T>(model, modelName, validData, numClasses);
    evalProcess.evalute(numClasses);
    return model;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 18, Source: TrainModel.java
Example 8: generateDecisionTreeWithPreprocessing
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public LogisticRegressionModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
        AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
        LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
        Marking marking,
        LogisticRegressionModelSummary logisticRegressionModelSummary) {
    // Note: despite the method names, this delegates to generateKMeansModel (Example 14),
    // which actually trains a logistic regression model with L-BFGS.
    return generateKMeansModel(
        rddPreProcessing(mongoRDD, athenaMLFeatureConfiguration, logisticRegressionModelSummary,
            marking),
        logisticRegressionDetectionAlgorithm, logisticRegressionModelSummary
    );
}
Developer: shlee89, Project: athena, Lines: 13, Source: LogisticRegressionDistJob.java
Example 9: OnlineFeatureHandler
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
        DetectionModel detectionModel,
        onlineMLEventListener onlineMLEventListener,
        ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());
    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        // unsupported ML model
        System.out.println("Not supported model");
    }
    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}
Developer: shlee89, Project: athena, Lines: 41, Source: OnlineFeatureHandler.java
Example 10: getModelInfo
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());
    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);
    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);
    return logisticRegressionModelInfo;
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 21, Source: LogisticRegressionModelInfoAdapter.java
Example 11: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);
    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);
    //Check that the two are equal field by field
    //(there may be edge cases, e.g. the order of elements in a list changing)
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), 0.01);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), 0.01);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), 0.01);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), 0.01);
    for (int i = 0; i < importedModel.getNumFeatures(); i++)
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], 0.01);
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 25, Source: LogisticRegressionExporterTest.java
Example 12: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);
    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);
    //Check that the two are equal field by field
    //(there may be edge cases, e.g. the order of elements in a list changing)
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);
    for (int i = 0; i < importedModel.getNumFeatures(); i++)
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON);
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 25, Source: LogisticRegressionExporterTest.java
Example 13: predictForMetrics_LogisticRegressionModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
public static JavaRDD<Tuple2<Object, Object>> predictForMetrics_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = data.map(
        new Function<LabeledPoint, Tuple2<Object, Object>>() {
            private static final long serialVersionUID = 1L;
            public Tuple2<Object, Object> call(LabeledPoint p) {
                Double prediction = model.predict(p.features());
                return new Tuple2<Object, Object>(prediction, p.label());
            }
        }
    );
    return predictionAndLabels;
}
Developer: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines: 13, Source: PredictUnit.java
Example 14: generateKMeansModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
// Despite its name, this method trains a logistic regression model with L-BFGS.
public LogisticRegressionModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
        LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
        LogisticRegressionModelSummary logisticRegressionModelSummary) {
    LogisticRegressionModel model
        = new LogisticRegressionWithLBFGS()
            .setNumClasses(logisticRegressionDetectionAlgorithm.getNumClasses())
            .run(parsedData.rdd());
    logisticRegressionModelSummary.setLogisticRegressionDetectionAlgorithm(logisticRegressionDetectionAlgorithm);
    return model;
}
Developer: shlee89, Project: athena, Lines: 12, Source: LogisticRegressionDistJob.java
Example 15: testLogisticRegression
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);
    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //validate predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");
        assertEquals(actual, predicted, 0.01);
    }
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 30, Source: LogisticRegressionBridgeTest.java
Example 16: testLogisticRegression
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);
    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //validate predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");
        assertEquals(actual, predicted, EPSILON);
    }
}
Developer: flipkart-incubator, Project: spark-transformers, Lines: 30, Source: LogisticRegressionBridgeTest.java
Example 17: instantiateSparkModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
private LogisticRegressionModel instantiateSparkModel() {
    Configuration conf = new Configuration();
    conf.set("fs.defaultFS", topologyConfig.getProperty("hdfs.url"));
    double[] sparkModelInfo = null;
    try {
        sparkModelInfo = getSparkModelInfoFromHDFS(new Path(topologyConfig.getProperty("hdfs.url") +
            "/tmp/sparkML_weights"), conf);
    } catch (Exception e) {
        LOG.error("Couldn't instantiate Spark model in prediction bolt: " + e.getMessage());
        e.printStackTrace();
        throw new RuntimeException(e);
    }
    // all numbers besides the last value are the weights
    double[] weights = Arrays.copyOfRange(sparkModelInfo, 0, sparkModelInfo.length - 1);
    // the last number in the array is the intercept
    double intercept = sparkModelInfo[sparkModelInfo.length - 1];
    org.apache.spark.mllib.linalg.Vector weightsV = (Vectors.dense(weights));
    return new LogisticRegressionModel(weightsV, intercept);
}
Developer: DhruvKumar, Project: iot-masterclass, Lines: 28, Source: PredictionBolt.java
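This example rebuilds the model from a raw weights-plus-intercept array because the Storm bolt has no SparkContext at prediction time. When a SparkContext is available on both sides, MLlib's built-in persistence is a simpler alternative; a hedged sketch, with paths and variable names purely illustrative:
// Producer side, after training (sc is a SparkContext):
trainedModel.save(sc, "hdfs:///tmp/sparkML_lr_model");

// Consumer side, instead of parsing a weights file by hand:
LogisticRegressionModel model = LogisticRegressionModel.load(sc, "hdfs:///tmp/sparkML_lr_model");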
Example 18: trainWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
/**
 * TODO add another overloaded method to avoid regularization.
 * This method uses the stochastic gradient descent (SGD) algorithm to train a logistic regression model.
 *
 * @param trainingDataset             Training dataset as a JavaRDD of labeled points
 * @param initialLearningRate         Initial learning rate
 * @param noOfIterations              Number of iterations
 * @param regularizationType          Regularization type: L1 or L2
 * @param regularizationParameter     Regularization parameter
 * @param dataFractionPerSGDIteration Data fraction per SGD iteration
 * @return Logistic regression model
 */
public LogisticRegressionModel trainWithSGD(JavaRDD<LabeledPoint> trainingDataset, double initialLearningRate,
        int noOfIterations, String regularizationType, double regularizationParameter,
        double dataFractionPerSGDIteration) {
    LogisticRegressionWithSGD lrSGD = new LogisticRegressionWithSGD(initialLearningRate, noOfIterations,
            regularizationParameter, dataFractionPerSGDIteration);
    if (MLConstants.L1.equals(regularizationType)) {
        lrSGD.optimizer().setUpdater(new L1Updater());
    } else if (MLConstants.L2.equals(regularizationType)) {
        lrSGD.optimizer().setUpdater(new SquaredL2Updater());
    }
    lrSGD.setIntercept(true);
    return lrSGD.run(trainingDataset.rdd());
}
Developer: wso2-attic, Project: carbon-ml, Lines: 26, Source: LogisticRegression.java
Example 19: trainWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
/**
 * This method uses the L-BFGS optimizer to train a logistic regression model for a given dataset.
 *
 * @param trainingDataset    Training dataset as a JavaRDD of labeled points
 * @param regularizationType Regularization type
 * @param noOfClasses        Number of classes
 * @return Logistic regression model
 */
public LogisticRegressionModel trainWithLBFGS(JavaRDD<LabeledPoint> trainingDataset, String regularizationType,
        int noOfClasses) {
    LogisticRegressionWithLBFGS lbfgs = new LogisticRegressionWithLBFGS();
    if (MLConstants.L1.equals(regularizationType)) {
        lbfgs.optimizer().setUpdater(new L1Updater());
    } else if (MLConstants.L2.equals(regularizationType)) {
        lbfgs.optimizer().setUpdater(new SquaredL2Updater());
    }
    lbfgs.setIntercept(true);
    return lbfgs.setNumClasses(noOfClasses < 2 ? 2 : noOfClasses).run(trainingDataset.rdd());
}
Developer: wso2-attic, Project: carbon-ml, Lines: 20, Source: LogisticRegression.java
Example 20: test
import org.apache.spark.mllib.classification.LogisticRegressionModel; // import the required package/class
/**
 * This method performs a binary classification using a given model and a dataset.
 *
 * @param logisticRegressionModel Logistic regression model
 * @param testingDataset          Testing dataset as a JavaRDD of LabeledPoints
 * @return Tuple2 containing scores and labels
 */
public JavaRDD<Tuple2<Object, Object>> test(final LogisticRegressionModel logisticRegressionModel,
        JavaRDD<LabeledPoint> testingDataset) {
    return testingDataset.map(
        new Function<LabeledPoint, Tuple2<Object, Object>>() {
            private static final long serialVersionUID = 910861043765821669L;
            public Tuple2<Object, Object> call(LabeledPoint labeledPoint) {
                Double score = logisticRegressionModel.predict(labeledPoint.features());
                return new Tuple2<Object, Object>(score, labeledPoint.label());
            }
        }
    );
}
Developer: wso2-attic, Project: carbon-ml, Lines: 21, Source: LogisticRegression.java
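A hedged sketch wiring Examples 19 and 20 together end to end: train with L-BFGS, score a held-out set, and compute the area under ROC. Whether the enclosing LogisticRegression class has a no-arg constructor is an assumption, MLConstants.L2 is the project constant referenced above, and trainingSet/testSet are illustrative JavaRDD<LabeledPoint> variables; clearing the threshold makes predict() return raw scores, which the ranking metric needs.
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; // assumed additional import

LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.trainWithLBFGS(trainingSet, MLConstants.L2, 2);
model.clearThreshold();  // raw scores instead of 0/1 labels
JavaRDD<Tuple2<Object, Object>> scoresAndLabels = lr.test(model, testSet);
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(scoresAndLabels.rdd());
System.out.println("Area under ROC = " + metrics.areaUnderROC());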
Note: The org.apache.spark.mllib.classification.LogisticRegressionModel class examples in this article were collected from source-code and documentation hosting platforms such as GitHub and MSDocs. The code snippets come from open-source projects contributed by their respective authors; copyright remains with the original authors, and redistribution or use should follow the corresponding project's license. Please do not republish without permission.