CART Algorithm Implementation
1. Python 3
This article builds two kinds of trees: a regression tree, in which each leaf node holds a single constant value, and a model tree, in which each leaf node holds a linear equation.
Pseudocode for createTree():
Find the best feature to split on:
    If the node cannot be split further, store the node as a leaf node
    Perform the binary split
    Call createTree() on the left subtree
    Call createTree() on the right subtree
from numpy import *

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        # map each field to float; list() is needed in Python 3,
        # where map() returns an iterator
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return dataMat
def binSplitDataSet(dataSet, feature, value):
    # dataSet is the data matrix, feature the index of the feature to split on,
    # and value the threshold for that feature
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
def regLeaf(dataSet):
    # returns the value used for each leaf: the mean of the target column
    return mean(dataSet[:, -1])

def regErr(dataSet):
    # total squared error: variance of the target times the number of samples
    return var(dataSet[:, -1]) * shape(dataSet)[0]
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # dataSet is assumed to be a NumPy matrix so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # choose the best split
    if feat is None:
        return val  # the split hit a stopping condition, so return a leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
Test the code above on some toy data:
import regTrees
from numpy import *

testMat = mat(eye(4))  # build a 4x4 identity matrix
print(testMat)
print(testMat[:, 1])   # the second column of testMat
# Split the matrix on column 1 at threshold 0.5: rows with a value > 0.5
# in that column go into mat0, rows with a value <= 0.5 into mat1
mat0, mat1 = regTrees.binSplitDataSet(testMat, 1, 0.5)
print(mat0)
print(mat1)
[[ 1. 0. 0. 0.]
[ 0. 1. 0. 0.]
[ 0. 0. 1. 0.]
[ 0. 0. 0. 1.]]
******
[[ 0.]
[ 1.]
[ 0.]
[ 0.]]
******
[[ 0. 1. 0. 0.]]
******
[[ 1. 0. 0. 0.]
[ 0. 0. 1. 0.]
[ 0. 0. 0. 1.]]
Using the CART algorithm for regression:
A regression tree assumes each leaf node holds a constant value; the underlying strategy is that complex relationships in the data can be summarized by a tree structure.
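A quick aside on the error measure: regErr() above returns var(y) * m, which is exactly the total squared error of the samples around their mean, the quantity each split tries to reduce. A minimal sanity check on toy data (not from the original post):

from numpy import *

dataSet = mat([[1.0], [2.0], [3.0], [4.0]])                  # toy target column
print(sum(power(dataSet[:, -1] - mean(dataSet[:, -1]), 2)))  # 5.0
print(var(dataSet[:, -1]) * shape(dataSet)[0])               # 5.0, the regErr() value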
chooseBestSplit():
Purpose: given an error-measurement method, this function finds the best binary split on the dataset. It also decides when to stop splitting; once splitting stops, a leaf node is generated. So the function has two jobs:
1. Split the dataset in the best way found
2. Generate the corresponding leaf node
Pseudocode:
For every feature:
    For every value of that feature:
        Split the dataset in two
        Measure the error of the split
        If the error is less than the current minimum error, set the current split as the best split and update the minimum error
Return the feature and threshold of the best split
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # tolS and tolN are user-specified parameters that control when the
    # function stops: tolS is the tolerated drop in error, tolN the minimum
    # number of samples allowed in a split
    tolS = ops[0]; tolN = ops[1]
    # If all remaining target values are equal, stop and return a leaf
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # Error of the current dataset; new splits are compared against S to
    # check whether splitting actually lowers the error
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # If the best split reduces the error by less than tolS, stop splitting
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # If either side of the split is smaller than tolN, stop splitting
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
Testing the code:
from numpy import *
import regTrees
import matplotlib.pyplot as plt

myDat = regTrees.loadDataSet("ex00.txt")
myMat = mat(myDat)
rtct = regTrees.createTree(myMat)
plt.plot(myMat[:, 0], myMat[:, 1], "ro")
plt.show()
Figure 5-1 A simple dataset used to build a regression tree with the CART algorithm
from numpy import *
import regTrees
import matplotlib.pyplot as plt

myDat1 = regTrees.loadDataSet("ex0.txt")
myMat1 = mat(myDat1)
print(myMat1)
print("****************")
rtct = regTrees.createTree(myMat1)
print(rtct)
plt.plot(myMat1[:, 1], myMat1[:, 2], "ro")
plt.show()
Figure 5-2 A piecewise-constant dataset used to test the regression tree
The resulting tree contains five leaf nodes. The steps above complete the construction of a regression tree, but we still need some way to check whether the construction is reasonable.
from numpy import *
import regTrees
import matplotlib.pyplot as plt

myDat2 = regTrees.loadDataSet("ex2.txt")
myMat2 = mat(myDat2)
retct = regTrees.createTree(myMat2)
print(retct)
plt.plot(myMat2[:, 0], myMat2[:, 1], "ro")
plt.show()
Figure 5-3 A new dataset: Figure 5-1 with its y-axis scaled by a factor of 100
Figure 5-3 looks very similar to Figure 5-1, but its y-values are two orders of magnitude larger. The new tree has many leaf nodes, while the tree built on Figure 5-1 had only two. This is because the stopping condition tolS is very sensitive to the magnitude of the error.
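To see this sensitivity directly, the tree can be rebuilt with a much larger tolS. This is an illustrative setting, not part of the original listing; in Machine Learning in Action, ops=(10000, 4) on this dataset collapses the tree back to two leaf nodes:

# Assumes myMat2 from the listing above. With tolS = 10000, a split must
# reduce the total squared error by at least 10000 to be accepted.
retct2 = regTrees.createTree(myMat2, ops=(10000, 4))
print(retct2)   # a much smaller tree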
Next we discuss post-pruning, which uses a test set to prune the tree. Since it requires no user-specified parameters, post-pruning is a more ideal approach.
Post-pruning requires splitting the data into a training set and a test set. First, set the parameters so that the tree built is large and complex enough to be worth pruning. Then, working from the top down, find the leaf nodes and use the test set to judge whether merging pairs of leaves would lower the test error; if so, merge them.
Pseudocode for prune():
Split the test data according to the existing tree:
    If either subset is itself a tree, recurse the pruning process on that subset
    Compute the error after merging the two current leaf nodes
    Compute the error without merging
    If merging lowers the error, merge the leaf nodes
def isTree(obj):
    # returns True if obj is a dict, i.e. a subtree rather than a leaf
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    # recursively collapse a tree to the mean value of its leaves
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
# Post-pruning of the tree
def prune(tree, testData):
    # tree: the tree to prune; testData: the test data used for pruning
    if shape(testData)[0] == 0:
        return getMean(tree)  # no test data reaches this node: collapse it
    # Assume overfitting has occurred and use the test data to prune the tree
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # After pruning the subtrees, check whether both children are now leaves
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # Compare the test error with and without merging the two leaves
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:  # merging reduces the error
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree
Testing the code:
from numpy import *
import regTrees
import matplotlib.pyplot as plt

myDat3 = regTrees.loadDataSet("exp2.txt")
myMat3 = mat(myDat3)
retct = regTrees.createTree(myMat3, ops=(0, 1))
rtp = regTrees.prune(retct, myMat3)
print(retct)
print(rtp)
plt.plot(myMat3[:, 0], myMat3[:, 1], "ro")
plt.show()
Figure 5-4 Piecewise-linear data used to test the model-tree construction function
Leaf-node generation functions for the model tree
# Model tree
def linearSolve(dataSet):
    # Format the dataset into X and Y for simple linear regression
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]  # the first column of X stays 1 for the intercept
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:     # the matrix must be invertible
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    # Generate a model-tree leaf when no further split is needed
    ws, X, Y = linearSolve(dataSet)
    return ws  # return the regression coefficients
def modelErr(dataSet):
    # Squared error of the linear model, used to find the best split
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))
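The output below comes from building a model tree on exp2.txt. The original post does not show the call that produced it; a sketch that should reproduce it, assuming the functions above live in regTrees.py and using ops=(1, 10) as in Machine Learning in Action (myMat4 is a name chosen here):

from numpy import *
import regTrees

myMat4 = mat(regTrees.loadDataSet("exp2.txt"))
# Pass the model-tree leaf and error functions instead of the defaults
retct = regTrees.createTree(myMat4, regTrees.modelLeaf, regTrees.modelErr, (1, 10))
print(retct)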
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[ 1.69855694e-03],
[ 1.19647739e+01]]), 'right': matrix([[ 3.46877936],
[ 1.18521743]])}
The model tree splits the data at 0.285477, while the data in Figure 5-4 are actually segmented at 0.3. The two linear models it produced are y = 3.468 + 1.1852x and y = 0.0016985 + 11.96477x, which are very close to the true models: the data were in fact generated from y = 3.5 + 1.0x and y = 0 + 12x plus Gaussian noise.
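As a small illustration, the fitted coefficients can be read directly off the returned dictionary; note that by the construction of binSplitDataSet(), the 'left' leaf covers the samples whose split feature is greater than spVal:

# Each model-tree leaf stores a weight vector [intercept, slope]
left_ws = retct['left']    # ~[0.0017, 11.9648] -> y = 0.0017 + 11.96x  (x > 0.285477)
right_ws = retct['right']  # ~[3.4688, 1.1852]  -> y = 3.469 + 1.185x   (x <= 0.285477)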
2. R Implementation
# Examine the correlations of all continuous variables and the chi-squared statistics of all categorical variables
> idx.num <- which(sapply(algae,is.numeric))
> idx.num
mxPH mnO2 Cl NO3 NH4 oPO4 PO4 Chla a1 a2 a3 a4 a5 a6 a7
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
> correlation <- cor(algae$a1,algae[,idx.num],use = "pairwise.complete.obs")
> correlation
mxPH mnO2 Cl NO3 NH4 oPO4
[1,] -0.2651354 0.2873732 -0.3711709 -0.2412111 -0.132656 -0.4173576
PO4 Chla a1 a2 a3 a4 a5
[1,] -0.4864228 -0.2779866 1 -0.2937678 -0.1465666 -0.03795656 -0.2915492
a6 a7
[1,] -0.2734283 -0.2129063
> correlation <- abs(correlation)
> correlation <- correlation[,order(correlation,decreasing = T)]
> correlation
a1 PO4 oPO4 Cl a2 a5 mnO2
1.00000000 0.48642276 0.41735761 0.37117086 0.29376781 0.29154923 0.28737317
Chla a6 mxPH NO3 a7 a3 NH4
0.27798661 0.27342831 0.26513541 0.24121109 0.21290633 0.14656656 0.13265601
a4
0.03795656
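For readers following the Python part of this post, an equivalent of the correlation screen above can be written with pandas (a sketch; it assumes the algae data are available in a hypothetical algae.csv with the same column names):

import pandas as pd

algae = pd.read_csv("algae.csv")                 # hypothetical file name
num_cols = algae.select_dtypes("number").columns
# Pairwise-complete correlations of every numeric column with a1
correlation = algae[num_cols].corrwith(algae["a1"]).abs()
print(correlation.sort_values(ascending=False))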
# Chi-squared test for the categorical variables
> idx.factor <- which(sapply(algae,is.factor))
> idx.factor
season size speed
1 2 3
> class(idx.factor)
[1] "integer"
> algae[,idx.factor]
season size speed
1 winter small medium
2 spring small medium
3 autumn small medium
4 spring small medium
5 autumn small medium
6 winter small high
7 summer small high
8 autumn small high
9 winter small medium
10 winter small high
……
> t1 <- table(algae$season,algae$size)
> t1
large medium small
autumn 11 16 13
spring 12 21 20
summer 10 21 14
winter 12 26 24
> chisq.test(t1)
Pearson's Chi-squared test
data: t1
X-squared = 1.662, df = 6, p-value = 0.948
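With a p-value of 0.948 there is no evidence of an association between season and size. The same Pearson chi-squared test can be reproduced in Python with scipy (a sketch; the counts are copied from the table t1 above):

from scipy.stats import chi2_contingency

# Contingency table of season (rows) by size (columns), from t1
t1 = [[11, 16, 13],   # autumn: large, medium, small
      [12, 21, 20],   # spring
      [10, 21, 14],   # summer
      [12, 26, 24]]   # winter
chi2, p, dof, expected = chi2_contingency(t1)
print(chi2, dof, p)   # roughly 1.662, 6, 0.948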