本文整理汇总了Java中org.apache.spark.mllib.classification.NaiveBayesModel类的典型用法代码示例。如果您正苦于以下问题：Java NaiveBayesModel类的具体用法？Java NaiveBayesModel怎么用？Java NaiveBayesModel使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
NaiveBayesModel类属于org.apache.spark.mllib.classification包，在下文中一共展示了NaiveBayesModel类的16个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: predictForMetrics
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Runs the given model over {@code data} and returns its output in the form used for metric computation.
 * <p>
 * For binary ({@code numClasses == 2}) logistic-regression and SVM models the threshold is cleared
 * on the passed-in model instance before predicting, i.e. the caller's model object is mutated.
 * Naive Bayes models are used as-is.
 *
 * @param modelName  one of "LogisticRegressionModel", "SVMModel" or "NaiveBayesModel"
 * @param model      model instance whose runtime class matches {@code modelName}
 * @param data       labeled points to score
 * @param numClasses number of classes; the threshold is only cleared when it equals 2
 * @return the RDD produced by the matching {@code PredictUnit.predictForMetrics_*} helper,
 *         or {@code null} when {@code modelName} is not one of the names above
 */
public JavaRDD<Tuple2<Object, Object>> predictForMetrics(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.clearThreshold();
            }
            return PredictUnit.predictForMetrics_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.clearThreshold();
            }
            return PredictUnit.predictForMetrics_SVMModel(svm, data);
        }
        case "NaiveBayesModel":
            return PredictUnit.predictForMetrics_NaiveBayesModel((NaiveBayesModel) model, data);
        default:
            // Unrecognised model name: nothing to predict with.
            return null;
    }
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:26,代码来源:PredictUnit.java
示例2: predictForOutput
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Runs the given model over {@code data} and returns its output in the form used for writing predictions.
 * <p>
 * For binary ({@code numClasses == 2}) logistic-regression and SVM models, {@code threshold} is set
 * on the passed-in model instance before predicting, i.e. the caller's model object is mutated.
 * Naive Bayes models ignore both {@code numClasses} and {@code threshold}.
 *
 * @param modelName  one of "LogisticRegressionModel", "SVMModel" or "NaiveBayesModel"
 * @param model      model instance whose runtime class matches {@code modelName}
 * @param data       labeled points to score
 * @param numClasses number of classes; the threshold is only applied when it equals 2
 * @param threshold  decision threshold applied to binary logistic-regression / SVM models
 * @return the RDD produced by the matching {@code PredictUnit.predictForOutput_*} helper,
 *         or {@code null} when {@code modelName} is not one of the names above
 */
public JavaRDD<Tuple2<Object, Object>> predictForOutput(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses, double threshold){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_SVMModel(svm, data);
        }
        case "NaiveBayesModel":
            return PredictUnit.predictForOutput_NaiveBayesModel((NaiveBayesModel) model, data);
        default:
            // Unrecognised model name: nothing to predict with.
            return null;
    }
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:27,代码来源:PredictUnit.java
示例3: PredictWithModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
@SuppressWarnings("unchecked")
public PredictWithModel(String modelName, String modelPath, String testFile, int numClasses, int minPartition, double threshold, SparkContext sc){
this.numClasses = numClasses;
this.threshold = threshold;
if(modelName.equals("LogisticRegressionModel")){
LogisticRegressionModel lrmodel = LogisticRegressionModel.load(sc, modelPath);
this.model = (T)(Object) lrmodel;
}
else if(modelName.equals("SVMModel")){
SVMModel svmmodel = SVMModel.load(sc, modelPath);
this.model = (T)(Object) svmmodel;
}
else if(modelName.equals("NaiveBayesModel")){
NaiveBayesModel bayesmodel = NaiveBayesModel.load(sc, modelPath);
this.model = (T)(Object) bayesmodel;
}
//Load testing data
LoadProcess loadProcess = new LoadProcess(sc, minPartition);
testingData = loadProcess.load(testFile, "Vector");
testingData.cache();
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:24,代码来源:PredictWithModel.java
示例4: OnlineFeatureHandler
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Creates a handler that applies the supplied detection model to online Athena features.
 * <p>
 * The Spark MLlib model wrapped by {@code detectionModel} is stored in the field matching the
 * wrapper's type; unsupported wrapper types only print a message. An online feature query built
 * from {@code featureConstraint} is then registered with a new event delivery manager.
 *
 * @param featureConstraint     constraint used when registering the online feature query
 * @param detectionModel        wrapper around the trained Spark MLlib model
 * @param onlineMLEventListener listener stored on this handler
 * @param controllerConnector   connector handed to the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());
    bindSparkModel(detectionModel);

    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}

/**
 * Stores the Spark MLlib model held by {@code wrapper} in the field that matches the wrapper's type.
 * Types are checked in a fixed order and the first match wins; unknown types only print a message.
 */
private void bindSparkModel(DetectionModel wrapper) {
    if (wrapper instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) wrapper.getDetectionModel();
        return;
    }
    if (wrapper instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) wrapper.getDetectionModel();
        return;
    }
    // No matching field for this kind of ML model.
    System.out.println("Not supported model");
}
开发者ID:shlee89,项目名称:athena,代码行数:41,代码来源:OnlineFeatureHandler.java
示例5: predictForMetrics_NaiveBayesModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Scores every point with a naive Bayes model for metric computation.
 *
 * @param model trained naive Bayes model (captured by the Spark closure, so it is shipped to executors)
 * @param data  labeled points to score
 * @return an RDD of {@code (prediction, label)} pairs, both boxed as {@code Double}
 */
public static JavaRDD<Tuple2<Object, Object>> predictForMetrics_NaiveBayesModel(NaiveBayesModel model, JavaRDD<LabeledPoint> data){
    // Spark's Function interface is Serializable, so this lambda can be shipped like the anonymous class it replaces.
    return data.map(point -> new Tuple2<Object, Object>(model.predict(point.features()), point.label()));
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:13,代码来源:PredictUnit.java
示例6: predictForOutput_NaiveBayesModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Scores every point with a naive Bayes model for writing prediction output.
 *
 * @param model trained naive Bayes model (captured by the Spark closure, so it is shipped to executors)
 * @param data  labeled points to score; only their feature vectors are used
 * @return an RDD of {@code (features, prediction)} pairs, the prediction boxed as {@code Double}
 */
public static JavaRDD<Tuple2<Object, Object>> predictForOutput_NaiveBayesModel(NaiveBayesModel model, JavaRDD<LabeledPoint> data){
    // Spark's Function interface is Serializable, so this lambda can be shipped like the anonymous class it replaces.
    return data.map(point -> new Tuple2<Object, Object>(point.features(), model.predict(point.features())));
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:13,代码来源:PredictUnit.java
示例7: trainWithBayes
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/** Smoothing parameter used by the no-argument {@code trainWithBayes()}; 1.0 is also MLlib's default. */
private static final double DEFAULT_BAYES_LAMBDA = 1.0;

/**
 * Trains a naive Bayes model on the training data with smoothing parameter 1.0 and evaluates it
 * on the validation data.
 *
 * @return the trained model, which is also stored in {@code this.model}
 */
public T trainWithBayes(){
    return trainWithBayes(DEFAULT_BAYES_LAMBDA);
}

/**
 * Trains a naive Bayes model on the training data with the given smoothing parameter and evaluates
 * it on the validation data.
 *
 * @param lambda smoothing parameter passed to {@code NaiveBayes.train}; Spark validates its range
 * @return the trained model, which is also stored in {@code this.model}
 */
@SuppressWarnings("unchecked")
public T trainWithBayes(double lambda){
    // Train the model
    NaiveBayesModel bayesmodel = NaiveBayes.train(trainingData.rdd(), lambda);
    // Unchecked cast: this class is generic in its model type T.
    this.model = (T)(Object) bayesmodel;
    // Evaluate the trained model on the validation data
    EvaluateProcess<T> evalProcess = new EvaluateProcess<T>(model, modelName, validData, numClasses);
    evalProcess.evalute(numClasses);
    return model;
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:12,代码来源:TrainModel.java
示例8: generateModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Trains a naive Bayes model on already pre-processed labeled points.
 *
 * @param parsedData                   labeled training points
 * @param naiveBayesDetectionAlgorithm supplies the smoothing parameter ({@code getLambda()})
 * @param naiveBayesModelSummary       summary that records which algorithm configuration was used
 * @return the trained model
 */
public NaiveBayesModel generateModel(JavaRDD<LabeledPoint> parsedData,
                                     NaiveBayesDetectionAlgorithm naiveBayesDetectionAlgorithm,
                                     NaiveBayesModelSummary naiveBayesModelSummary) {
    double smoothing = naiveBayesDetectionAlgorithm.getLambda();
    NaiveBayesModel trained = NaiveBayes.train(parsedData.rdd(), smoothing);
    // Recorded only after training has completed.
    naiveBayesModelSummary.setNaiveBayesDetectionAlgorithm(naiveBayesDetectionAlgorithm);
    return trained;
}
开发者ID:shlee89,项目名称:athena,代码行数:10,代码来源:NaiveBayesDistJob.java
示例9: generateModelWithPreprocessing
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Converts raw (key, BSON document) records into labeled points and trains a naive Bayes model on them.
 *
 * @param mongoRDD                     raw input records
 * @param athenaMLFeatureConfiguration feature configuration used during pre-processing
 * @param naiveBayesDetectionAlgorithm algorithm settings used for training
 * @param marking                      labelling rules used during pre-processing
 * @param naiveBayesModelSummary       summary updated by pre-processing and training
 * @return the trained model
 */
public NaiveBayesModel generateModelWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                      AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                      NaiveBayesDetectionAlgorithm naiveBayesDetectionAlgorithm,
                                                      Marking marking,
                                                      NaiveBayesModelSummary naiveBayesModelSummary) {
    JavaRDD<LabeledPoint> labeledPoints = rddPreProcessing(
            mongoRDD, athenaMLFeatureConfiguration, naiveBayesModelSummary, marking);
    return generateModel(labeledPoints, naiveBayesDetectionAlgorithm, naiveBayesModelSummary);
}
开发者ID:shlee89,项目名称:athena,代码行数:13,代码来源:NaiveBayesDistJob.java
示例10: setNaiveBayesModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Sets the trained Spark MLlib naive Bayes model held by this detection model.
 *
 * @param naiveBayesModel the model to store; it is stored as-is (no copy, no null check)
 */
public void setNaiveBayesModel(NaiveBayesModel naiveBayesModel) {
this.naiveBayesModel = naiveBayesModel;
}
开发者ID:shlee89,项目名称:athena,代码行数:4,代码来源:NaiveBayesDetectionModel.java
示例11: validate
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Validates a trained naive Bayes model against the records in {@code mongoRDD}.
 * <p>
 * For each record, the configured target feature values are extracted (optionally weighted,
 * made absolute and normalized). The model then predicts a label, and the prediction plus the
 * per-record processing time are reported to {@code naiveBayesValidationSummary}.
 * <p>
 * NOTE(review): the per-record work runs inside {@code mongoRDD.foreach}, i.e. in a Spark closure on
 * executors. Updates made to {@code naiveBayesValidationSummary} there only reach the driver if that
 * class aggregates through Spark accumulators (it is built with a SparkContext elsewhere); confirm.
 * The anonymous {@code VoidFunction} may also capture the enclosing instance, which would then need
 * to be serializable; confirm.
 *
 * @param mongoRDD                     (key, BSON document) records to validate against
 * @param athenaMLFeatureConfiguration target features, per-feature weights, and absolute/normalization flags
 * @param naiveBayesDetectionModel     supplies the trained model, the marking rules and the detection algorithm
 * @param naiveBayesValidationSummary  receives per-record results and timings
 */
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
NaiveBayesDetectionModel naiveBayesDetectionModel,
NaiveBayesValidationSummary naiveBayesValidationSummary) {
List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
Marking marking = naiveBayesDetectionModel.getMarking();
NaiveBayesModel model = (NaiveBayesModel) naiveBayesDetectionModel.getDetectionModel();
Normalizer normalizer = new Normalizer();
int numberOfTargetValue = listOfTargetFeatures.size();
mongoRDD.foreach(new VoidFunction<Tuple2<Object, BSONObject>>() {
public void call(Tuple2<Object, BSONObject> t) throws UnknownHostException {
long start2 = System.nanoTime(); // per-record timer start
BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
BSONObject idx = (BSONObject) t._2();
// Ground-truth label from the marking rules; it is only wrapped into the LabeledPoint below,
// and prediction uses the features alone.
int originLabel = marking.checkClassificationMarkingElements(idx,feature);
double[] values = new double[numberOfTargetValue];
for (int j = 0; j < numberOfTargetValue; j++) {
// A target feature missing from the document contributes 0.
values[j] = 0;
if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
if (obj instanceof Long) {
values[j] = (Long) obj;
} else if (obj instanceof Double) {
values[j] = (Double) obj;
} else if (obj instanceof Boolean) {
values[j] = (Boolean) obj ? 1 : 0;
} else {
// Any other value type skips the WHOLE record: nothing is predicted, reported or timed for it.
// NOTE(review): BSON int32 fields (Integer) also land here; confirm this is intended.
return;
}
//check weight: multiply by the configured per-feature weight, if any
if (weight.containsKey(listOfTargetFeatures.get(j))) {
values[j] *= weight.get(listOfTargetFeatures.get(j));
}
//check absolute: optionally use the magnitude only
if (athenaMLFeatureConfiguration.isAbsolute()){
values[j] = Math.abs(values[j]);
}
}
}
Vector normedForVal;
if (athenaMLFeatureConfiguration.isNormalization()) {
normedForVal = normalizer.transform(Vectors.dense(values));
} else {
normedForVal = Vectors.dense(values);
}
LabeledPoint p = new LabeledPoint(originLabel,normedForVal);
// The predicted label (a double) is truncated to int for the summary.
int validatedLabel = (int) model.predict(p.features());
naiveBayesValidationSummary.updateSummary(validatedLabel,idx,feature);
long end2 = System.nanoTime();
long result2 = end2 - start2;
naiveBayesValidationSummary.addTotalNanoSeconds(result2);
}
});
// NOTE(review): the return value is discarded; this call only matters if the getter has side effects. Verify.
naiveBayesValidationSummary.getAverageNanoSeconds();
naiveBayesValidationSummary.setNaiveBayesDetectionAlgorithm((NaiveBayesDetectionAlgorithm) naiveBayesDetectionModel.getDetectionAlgorithm());
}
开发者ID:shlee89,项目名称:athena,代码行数:69,代码来源:NaiveBayesDistJob.java
示例12: generateNaiveBayesAthenaDetectionModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Builds a naive Bayes detection model from the records read through {@code mongodbConfig}.
 * <p>
 * The returned model carries the algorithm settings, feature constraint, feature configuration,
 * indexing, marking, the trained Spark model and a summary. That summary includes the total learning
 * time, measured from just after the summary is created until the trained model has been attached.
 *
 * @param sc                           Spark context used to read the input and build the summary
 * @param featureConstraint            constraint stored on the detection model
 * @param athenaMLFeatureConfiguration feature configuration used for pre-processing
 * @param detectionAlgorithm           must be a {@code NaiveBayesDetectionAlgorithm}
 * @param indexing                     indexing settings stored on the model and summary
 * @param marking                      labelling rules used for training
 * @return the populated detection model
 */
public NaiveBayesDetectionModel generateNaiveBayesAthenaDetectionModel(JavaSparkContext sc,
                                                                       FeatureConstraint featureConstraint,
                                                                       AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                       DetectionAlgorithm detectionAlgorithm,
                                                                       Indexing indexing,
                                                                       Marking marking) {
    NaiveBayesModelSummary summary = new NaiveBayesModelSummary(sc.sc(), indexing, marking);

    // Learning-time measurement begins here (after the summary has been created).
    long startedAt = System.nanoTime();

    NaiveBayesDetectionAlgorithm algorithm = (NaiveBayesDetectionAlgorithm) detectionAlgorithm;
    NaiveBayesDetectionModel detection = new NaiveBayesDetectionModel();
    detection.setNaiveBayesDetectionAlgorithm(algorithm);
    summary.setNaiveBayesDetectionAlgorithm(algorithm);
    detection.setFeatureConstraint(featureConstraint);
    detection.setAthenaMLFeatureConfiguration(athenaMLFeatureConfiguration);
    detection.setIndexing(indexing);
    detection.setMarking(marking);

    // Read (key, BSON document) records from a live cluster through the Mongo Hadoop input format.
    JavaPairRDD<Object, BSONObject> mongoRDD = sc.newAPIHadoopRDD(
            mongodbConfig, MongoInputFormat.class, Object.class, BSONObject.class);

    NaiveBayesModel trainedModel = new NaiveBayesDistJob().generateModelWithPreprocessing(
            mongoRDD, athenaMLFeatureConfiguration, algorithm, marking, summary);
    detection.setNaiveBayesModel(trainedModel);

    // Learning-time measurement ends once the trained model is attached.
    summary.setTotalLearningTime(System.nanoTime() - startedAt);
    detection.setClassificationModelSummary(summary);
    return detection;
}
开发者ID:shlee89,项目名称:athena,代码行数:45,代码来源:MachineLearningManagerImpl.java
示例13: buildNaiveBayesModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Builds a naive Bayes model, evaluates it on the test split and returns the resulting summary.
 * <p>
 * The cached RDDs this method owns ({@code predictionsAndLabels} and, once cached here, {@code testingData})
 * are released in a {@code finally} block. This runs only after every metric has been computed, so it
 * happens on both the success and the failure path.
 *
 * @param sparkContext     JavaSparkContext initialized with the application
 * @param modelID          Model ID (not used by this method body)
 * @param trainingData     Training data as a JavaRDD of LabeledPoints; unpersisted right after training
 * @param testingData      Testing data as a JavaRDD of LabeledPoints; cached while it is being evaluated
 * @param workflow         Machine learning workflow; supplies the {@code MLConstants.LAMBDA} hyper-parameter
 *                         and the dataset version
 * @param mlModel          Deployable machine learning model; receives the trained model
 * @param includedFeatures Feature names, in key order, recorded in the summary
 * @return the classification summary (features, algorithm, confusion matrix, accuracy, dataset version)
 * @throws MLModelBuilderException if any step fails; the original exception is attached as the cause
 */
private ModelSummary buildNaiveBayesModel(JavaSparkContext sparkContext, long modelID,
        JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow, MLModel mlModel,
        SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    boolean testingDataCached = false;
    try {
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
        NaiveBayesModel naiveBayesModel = naiveBayesClassifier.train(trainingData,
                Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));
        // remove from cache: the training split is not needed once the model exists
        trainingData.unpersist();
        // add test data to cache
        testingData.cache();
        testingDataCached = true;
        predictionsAndLabels = naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
        ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary = SparkModelUtils
                .getClassClassificationModelSummary(sparkContext, testingData, predictionsAndLabels);
        mlModel.setModel(new MLClassificationModel(naiveBayesModel));
        classClassificationAndRegressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
        classClassificationAndRegressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());
        // The metric helpers below may run Spark actions on predictionsAndLabels. Keeping it (and testingData)
        // cached until they finish avoids recomputing the prediction lineage.
        MulticlassMetrics multiclassMetrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
        classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(
                multiclassMetrics, mlModel));
        Double modelAccuracy = getModelAccuracy(multiclassMetrics);
        classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
        classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());
        return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
        throw new MLModelBuilderException("An error occurred while building naive bayes model: " + e.getMessage(),
                e);
    } finally {
        // Release only what this method cached, on success and on failure alike.
        if (predictionsAndLabels != null) {
            predictionsAndLabels.unpersist();
        }
        if (testingDataCached) {
            testingData.unpersist();
        }
    }
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:55,代码来源:SupervisedSparkModelBuilder.java
示例14: loadModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Restores a naive Bayes model that was previously saved to {@code modelPath}.
 *
 * @param sc        active Spark context
 * @param modelPath location the model was saved to
 * @return the loaded model
 */
@Override
protected NaiveBayesModel loadModel(SparkContext sc, String modelPath) {
    final NaiveBayesModel restored = NaiveBayesModel.load(sc, modelPath);
    return restored;
}
开发者ID:IBMStreams,项目名称:streamsx.sparkMLLib,代码行数:5,代码来源:SparkNaiveBayes.java
示例15: train
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Trains a naive Bayes model on the given labeled points.
 *
 * @param trainingData training set as a JavaRDD of labeled points
 * @param lambda       smoothing parameter handed to {@code NaiveBayes.train}
 * @return the trained naive Bayes model
 */
public NaiveBayesModel train(JavaRDD<LabeledPoint> trainingData, double lambda) {
    // MLlib's NaiveBayes.train expects the underlying Scala RDD, not the Java wrapper.
    org.apache.spark.rdd.RDD<LabeledPoint> points = trainingData.rdd();
    return NaiveBayes.train(points, lambda);
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:11,代码来源:NaiveBayesClassifier.java
示例16: NaiveBayesClassificationModel
import org.apache.spark.mllib.classification.NaiveBayesModel; //导入依赖的package包/类
/**
 * Instantiates a new naive Bayes classification model wrapping a trained Spark MLlib model.
 *
 * @param modelId the model id
 * @param model   the trained {@link NaiveBayesModel}; both arguments are passed unchanged to the
 *                superclass constructor
 */
public NaiveBayesClassificationModel(String modelId, NaiveBayesModel model) {
super(modelId, model);
}
开发者ID:apache,项目名称:lens,代码行数:10,代码来源:NaiveBayesClassificationModel.java
注：本文中的org.apache.spark.mllib.classification.NaiveBayesModel类示例整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论