本文主要参考了Python实现ID3算法,对浅谈决策树算法以及matlab实现ID3算法中的代码作了少许改动,用Map代替Struct从而实现中文字符的存储,并且可以有多个分叉。
处理数据为csv格式:
色泽,根蒂,敲声,纹理,脐部,触感,好瓜
青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,稍蜷,浊响,清晰,稍凹,软粘,是
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
青绿,硬挺,清脆,清晰,平坦,软粘,否
浅白,硬挺,清脆,模糊,平坦,硬滑,否
浅白,蜷缩,浊响,模糊,平坦,软粘,否
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
浅白,蜷缩,浊响,模糊,平坦,硬滑,否
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否
使用时,需要先将数据以元胞或者字符列表的格式导入MATLAB,之后进行操作。
%Data preprocessing: split the imported table into data rows and attribute names
%uiopen('E:\MATLAB\Machine_Learning\watermelon_2.csv',1)
size_data = size(watermelon2); %watermelon2 is the cell array imported into the workspace
dataset = watermelon2(2:size_data(1),:); %data rows only (header row stripped)
labels = watermelon2(1,1:size_data(2)-1); %attribute names from the header (class column excluded)
%Build the decision tree, then traverse and plot it
mytree = ID3(dataset,labels);
[nodeids,nodevalue,branchvalue] = print_tree(mytree);
tree_plot(nodeids,nodevalue,branchvalue)
结果为:
函数代码文件
calShannonEnt.m
function shannonEnt = calShannonEnt(dataset)
% calShannonEnt  Shannon entropy (base 2) of the class column of a dataset.
%   dataset: cell array of samples; the LAST column holds the class label.
%   Returns the entropy of the empirical class-label distribution.
data_size = size(dataset);
labels = dataset(:,data_size(2));   % class labels = last column
numEntries = data_size(1);          % number of samples
% Count occurrences of each distinct label.
% containers.Map is used (instead of struct fields) so that Chinese
% labels are valid keys.
labelCounts = containers.Map;
for i = 1:length(labels)
    label = char(labels(i));
    if labelCounts.isKey(label)
        labelCounts(label) = labelCounts(label) + 1;
    else
        labelCounts(label) = 1;
    end
end
% Accumulate -p*log2(p) over the label distribution.
shannonEnt = 0.0;
for key = labelCounts.keys
    key = char(key);
    prob = labelCounts(key) / numEntries;
    shannonEnt = shannonEnt - prob*log2(prob);
end
end
splitDataset.m
function subDataset = splitDataset(dataset,axis,value)
% splitDataset  Select the rows whose column `axis` equals `value`,
% returning them with that column removed.
%   dataset: cell array of samples
%   axis   : column (feature) index to test
%   value  : feature value to match
subDataset = {};
nRows = size(dataset,1);
for r = 1:nRows
    row = dataset(r,:);
    % Compare via string() so char / cell contents match robustly.
    if string(row(axis)) == string(value)
        kept = [row(1:axis-1) row(axis+1:end)];
        subDataset = [subDataset; kept]; %#ok<AGROW>
    end
end
chooseFeature.m
function bestFeature = chooseFeature(dataset,~)
% chooseFeature  Index of the attribute whose split gives the lowest
% weighted (conditional) entropy — equivalently, the highest information
% gain, since the base entropy is identical for every candidate feature.
%   dataset: cell array; last column is the class label.
%   Second argument is unused; kept only for interface compatibility.
data_size = size(dataset);
numFeatures = data_size(2) - 1;   % last column is the class, not a feature
minEntropy = 2.0;                 % safe upper bound (entropy of 2 classes <= 1)
bestFeature = 0;
for i = 1:numFeatures
    uniqueVals = unique(dataset(:,i));
    % Weighted entropy of the partition induced by feature i.
    newEntropy = 0.0;
    for j = 1:length(uniqueVals)
        subDataset = splitDataset(dataset,i,uniqueVals(j));
        size_sub = size(subDataset);
        prob = size_sub(1)/data_size(1);
        newEntropy = newEntropy + prob*calShannonEnt(subDataset);
    end
    if newEntropy < minEntropy
        minEntropy = newEntropy;
        bestFeature = i;
    end
end
end
ID3.m
function myTree = ID3(dataset,labels)
% ID3  Build a decision tree with the ID3 algorithm.
%   dataset: cell array of samples; last column is the class label.
%   labels : cell array of attribute names for columns 1..end-1.
%   Returns either a nested containers.Map (internal node: attribute
%   name -> Map of value -> subtree) or a char class label (leaf).
if(isempty(dataset))
    error('必须提供数据!')
end
size_data = size(dataset);
if (size_data(2)-1)~=length(labels)
    error('属性数量与数据集不一致!')
end
classList = dataset(:,size_data(2));
% All remaining samples share one class: entropy is 0, emit a leaf.
if length(unique(classList))==1
    myTree = char(classList(1));
    return
end
% Attribute set exhausted: emit the MAJORITY class among the remaining
% samples (standard ID3 behavior, rather than arbitrarily the first one).
if size_data(2) == 1
    [classes,~,idx] = unique(classList);
    [~,majority] = max(accumarray(idx,1));
    myTree = char(classes(majority));
    return
end
bestFeature = chooseFeature(dataset);
bestFeatureLabel = char(labels(bestFeature));
% Node layout: Map{attribute name} -> Map{attribute value -> subtree}.
% containers.Map (instead of struct) allows Chinese-text keys.
myTree = containers.Map;
leaf = containers.Map;
featValues = dataset(:,bestFeature);
uniqueVals = unique(featValues);
labels = [labels(1:bestFeature-1) labels(bestFeature+1:end)]; % drop the used attribute
for i = 1:length(uniqueVals)
    subLabels = labels(:)';
    value = char(uniqueVals(i));
    subdata = splitDataset(dataset,bestFeature,value);
    leaf(value) = ID3(subdata,subLabels);
end
myTree(bestFeatureLabel) = leaf;
end
print_tree.m
function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
% print_tree  Level-order (BFS) traversal of the decision tree.
%   tree: nested containers.Map built by ID3 (or a char label for a
%         single-leaf tree).
%   Returns:
%     nodeids_     parent index per node, treelayout format (root = 0)
%     nodevalue_   node texts (attribute names / class labels), BFS order
%     branchvalue_ branch texts (attribute values), BFS order
nodeids(1) = 0;
nodeid = 0;
nodevalue={};
branchvalue={};
queue = {tree} ;
while ~isempty(queue)
node = queue{1};
queue(1) = [];
if string(class(node))~="containers.Map" %leaf node: a plain char class label
nodeid = nodeid+1;
nodevalue = [nodevalue,{node}];
elseif length(node.keys)==1 %internal node: single-key Map (attribute name)
nodevalue = [nodevalue,node.keys];
% node_info maps each attribute value to its child subtree
node_info = node(char(node.keys));
nodeid = nodeid+1;
branchvalue = [branchvalue,node_info.keys];
% record the current node as the parent of each of its children
for i=1:length(node_info.keys)
nodeids = [nodeids,nodeid];
%nodeids(nodeid+length(queue)+i) = nodeid;
end
% else
% nodeid = nodeid+1;
% branchvalue = [branchvalue,node.keys];
% for i=1:length(node.keys)
% %nodeids = [nodeids,nodeid];
% nodeids(nodeid+length(queue)+i) = nodeid;
% end
% %nodeid = nodeid+1;
end
% enqueue children so the traversal stays breadth-first
if string(class(node))=="containers.Map"
%nodeid = nodeid+1;
keys = node.keys();
for i = 1:length(keys)
key = keys{i};
%nodeids(nodeid+length(queue)+i) = nodeid;
%nodevalue{1,nodeid} = key ;
queue=[queue,{node(key)}];
end
end
% outputs refreshed each iteration; final pass holds the full result
nodeids_=nodeids;
nodevalue_=nodevalue;
branchvalue_ = branchvalue;
end
tree_plot.m
function tree_plot(p,nodevalue,branchvalue)
% tree_plot  Draw the decision tree (adapted from MATLAB's treeplot).
%   p          : parent-index vector as produced by print_tree
%   nodevalue  : text shown beside each node
%   branchvalue: text shown on the edge leading to each non-root node
[x,y,h] = treelayout(p); %x: node x-coords, y: node y-coords; h: tree depth
f = find(p~=0); %indices of non-root nodes
pp = p(f); %their parents
% Each edge contributes a [child; parent; NaN] triple so one plot call
% draws all segments (NaN breaks the line between edges).
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];
X = X(:);
Y = Y(:);
n = length(p);
if n<500
hold on;
plot(x,y,'ro',X,Y,'r-')
nodesize = length(x);
for i=1:nodesize
text(x(i)+0.01,y(i),nodevalue{1,i});
end
% Label each edge at its midpoint; node i's edge occupies columns
% j=3*i-5 (child) and j+1 (parent) of the flattened X/Y triples.
% The length-based x-offset roughly centers the label on the edge.
for i=2:nodesize
%text(x(i)-0.02,y(i)+0.01,branchvalue{1,i-1})
j = 3*i-5;
text((X(j)+X(j+1))/2-length(char(branchvalue{1,i-1}))/200,(Y(j)+Y(j+1))/2,branchvalue{1,i-1})
end
hold off
else
plot(X,Y,'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end
因为是从Python代码转成Matlab的,加之对MATLAB不甚了解,中间有很多待优化的过程,甚至是某些纰漏,欢迎大家来拍砖。
其中比较难以理解的是nodeids的获取与构造,可以参考:https://blog.csdn.net/alpes2012/article/details/79504841
|
请发表评论