• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

matlab svmtrain和svmclassify函数使用示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

监督式学习(Supervised Learning)常用算法包括:线性回归(Linear Regression)、逻辑回归(Logistic Regression)、神经网络(Neural Network)以及支持向量机(Support Vector Machine,SVM)等。支持向量机与逻辑回归算法类似,都是解决二分类或多分类问题,但是SVM在非线性回归预测方面具有更优秀的分类效果,所以SVM又被称为最大间距分类器。

本文不对支持向量机的原理进行详细解释,直接运用matlab自带的工具箱函数svmtrain、svmclassify解决实际的二分类问题。

导入数据:

 

 
  1. clear; close all; clc;

  2. %% ================ load fisheriris.mat ================

  3. load fisheriris.mat

1、对于线性分类问题,我们选取线性核函数,原始数据包括训练数据和测试数据两部分。

 

 
  1. data = meas(51:end,3:4); % column 3,column 4作为特征值

  2. group = species(51:end); % 类别

  3. idx = randperm(size(data,1));

  4. N = length(idx);

  5.  
  6. % SVM train

  7. T = floor(N*0.9); % 90组数据作为训练数据

  8. xdata = data(idx(1:T),:);

  9. xgroup = group(idx(1:T));

  10. svmStr = svmtrain(xdata,xgroup,\'Showplot\',true);


训练过程得到结构体svmStr,对测试数据进行预测

 

 
  1. % SVM predict

  2. P = floor(N*0.1); % 10组预测数据

  3. ydata = data(idx(T+1:end),:);

  4. ygroup = group(idx(T+1:end));

  5. pgroup = svmclassify(svmStr,ydata,\'Showplot\',true); % svm预测

  6. hold on;

  7. plot(ydata(:,1),ydata(:,2),\'bs\',\'Markersize\',12);

  8. accuracy1 = sum(strcmp(pgroup,ygroup))/P*100; % 预测准确性

  9. hold off;


程序运行结果如下:

图中,方块*号表示测试数据的预测结果,accuracy1结果为90%(上下浮动)。

 

2、对于非线性分类问题,我们选取高斯核函数RBF,原始数据包括训练数据和测试数据两部分。

训练过程前,导入原始数据:

 

 
  1. data = meas(51:end,1:2); % column 1,column 2作为特征值

  2. group = species(51:end); % 类别

  3. idx = randperm(size(data,1));

  4. N = length(idx);

  5.  
  6. % SVM train

  7. T = floor(N*0.9); % 90组数据作为训练数据

  8. xdata = data(idx(1:T),:);

  9. xgroup = group(idx(1:T));


对于高斯核函数,有两个参数对SVM的分类效果有着重要的影响:一个是sigma;另一个是C。

首先讨论sigma的影响,sigma反映了RBF函数从最大值点向周围函数值下降的速度,sigma越大,下降速度越慢,对应RBF函数越平缓;sigma越小,下降速度越快,对应RBF函数越陡峭。对于不同的sigma,程序代码:

 

 
  1. % different sigma

  2. figure;

  3. sigma = 0.5;

  4. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'rbf_sigma\',...

  5. sigma,\'showplot\',true);

  6. title(\'sigma = 0.5\');

  7. figure;

  8. sigma = 1;

  9. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'rbf_sigma\',...

  10. sigma,\'showplot\',true);

  11. title(\'sigma = 1\');

  12. figure;

  13. sigma = 3;

  14. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'rbf_sigma\',...

  15. sigma,\'showplot\',true);

  16. title(\'sigma = 3\');


分类平面分别如下:

 

 

 

从图中可以看出,sigma越小,分类曲线越复杂,事实也确实如此。因为sigma越小,RBF函数越陡峭,下降速度越大,预测过程容易发生过拟合问题,使分类模型对训练数据过分拟合,而对测试数据预测效果不佳。

然后讨论C的影响,程序代码如下:

 

 
  1. % different C

  2. figure;

  3. C = 1;

  4. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'boxconstraint\',...

  5. C,\'showplot\',true);

  6. title(\'C = 0.1\');

  7. figure;

  8. C = 8;

  9. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'boxconstraint\',...

  10. C,\'showplot\',true);

  11. title(\'C = 1\');

  12. figure;

  13. C = 64;

  14. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'boxconstraint\',...

  15. C,\'showplot\',true);

  16. title(\'C = 10\');


分类平面如下:

 

 

 

从图中可以发现,C越大,分类曲线越复杂,也就是说越容易发生过拟合,因为C对应逻辑回归的lambda的倒数。

若令sigma=1,C=1,则对测试数据的预测程序:

 

 
  1. % SVM predict

  2. P = floor(N*0.1); % 10组预测数据

  3. ydata = data(idx(T+1:end),:);

  4. ygroup = group(idx(T+1:end));

  5. % sigma = 1,C = 1,default

  6. figure;

  7. svmStr = svmtrain(xdata,xgroup,\'kernel_function\',\'rbf\',\'showplot\',true);

  8. pgroup = svmclassify(svmStr,ydata,\'Showplot\',true); % svm预测

  9. hold on;

  10. plot(ydata(:,1),ydata(:,2),\'bs\',\'Markersize\',12);

  11. accuracy2 = sum(strcmp(pgroup,ygroup))/P*100; % 预测准确性

  12. hold off;


程序运行结果如下:

图中,方块*号表示测试数据的预测结果,accuracy2结果为70%(上下浮动)。

分类效果不佳因为两个特征量的选择,可以选择更合适的特征量。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
matlab练习程序(卡尔曼滤波)发布时间:2022-07-18
下一篇:
DelphiADOQuery的速度优化转发布时间:2022-07-18
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap