• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

机器学习实战 决策树 python3实现 R语言实现(2)

原作者: [db:作者] 来自: [db:来源] 收藏 邀请


CART算法实现

1.python3

本文将构建两种树,一种回归树(regression tree),其每个节点包含单个值;第二种是模型树(model tree),其每个叶节点包含一个线性方程。

createTree()伪代码:

找到最佳的待切分特征:

如果该节点不能再分,将该节点存为叶节点

执行二元切分

在右子树调用createTree()方法

在左子树调用createTree()方法

 

from numpy import *

def loadDataSet(fileName):      
    dataMat = []                
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine) #将每行映射成浮点数
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):

#dataSet是数据集合,feature是待切分的特征,value是该特征的某个值
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]  
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value),:][0]
    return mat0,mat1

def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

对上述代码构造假数据进行测试:

import regTrees
from numpy import*
testMat = mat(eye(4))  #构造4行4列单位阵
print(testMat)
print(testMat[:,1])#testMat矩阵中的第2列
mat0,mat1 = regTrees.binSplitDataSet(testMat,1,0.5)  #调用binSplitDataSet函数切分该矩阵。把第二列数据中大于0.5的分为一类,小于等于0.5的分为第二类。
print(mat0)
print(mat1)

[[ 1.  0.  0.  0.]

 [ 0.  1.  0.  0.]

 [ 0.  0.  1.  0.]

 [ 0.  0.  0.  1.]]

******

[[ 0.]

 [ 1.]

 [ 0.]

 [ 0.]]

******

[[ 0.  1.  0.  0.]]

******

[[ 1.  0.  0.  0.]

 [ 0.  0.  1.  0.]

 [ 0.  0.  0.  1.]]

CART算法用于回归:

回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。

chooseBestSplit():

功能:给定某个误差计算方法,该函数会找到数据集熵最佳的二元切分方式。此外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。所以该函数需要完成两个方面:

1.用最佳方式切分数据集

2.生成相应的叶节点

伪代码:

对每个特征:

对每个特征值:

将数据集切分成两份

计算切分的误差

如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差返回最佳切分的特征和阈值

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
   #为ops设定了tols和tolN这两个值。他们是用户指定的参数,用于控制函数的停止时机。其中变量tols是容许误差下降的值,tolN是切分的最小样本数。接下来函数会统计不同剩余特征值的数目,如果该数目为1,那么就不需要在进行切分而直接返回。然后函数计算了当数据集的大小和误差。该误差S将用于与新切分的误差进行对比,来检查切分是否能够降低误差。

    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)

    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):

        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):

            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)

            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS

    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex,bestValue

 

代码测试:

from numpy import *
import regTrees
import matplotlib.pyplot as plt
myDat = regTrees.loadDataSet("ex00.txt")
myMat = mat(myDat)
rtct = regTrees.createTree(myMat)
plt.plot(myMat[:,0],myMat[:,1],"ro")
plt.show()

     图5-1  基于CART算法构建回归树的简单数据集


from numpy import *
import regTrees
import matplotlib.pyplot as plt
myDat1 = regTrees.loadDataSet("ex0.txt")
myMat1 = mat(myDat1)
print(myMat1)
print("****************")
rtct = regTrees.createTree(myMat1)
print(rtct)
plt.plot(myMat1[:,1],myMat1[:,2],"ro")
plt.show()


5-2  用于测试回归树的分段常数数据集


该树包含5个叶节点。上述过程已经完成回归树的构建,但是需要某种措施来检查构建的过程是否恰当。

from numpy import *
import regTrees
import matplotlib.pyplot as plt
myDat2 = regTrees.loadDataSet("ex2.txt")
myMat2 = mat(myDat2)
retct = regTrees.createTree(myMat2)
print((retct))
plt.plot(myMat2[:,0],myMat2[:,1],"ro")
plt.show()


5-3  将图5-1的轴扩大100倍后的新数据集


5-3看上去和图5-1分长相思但是y轴的数量级是5-1的100倍这里的新树有很多叶节点,而5-1只有2个。这是因为停止条件tolS对误差的数量级十分敏感。

下面讨论后剪枝,即利用测试集来对树进行剪枝,由于不需要用户指定参数,后剪枝是一个更理想化的方法。

后剪枝方法需要将数据集分类成测试集和训练集。首先指定参数,使得构建出的树足够大,足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。

Prune()伪代码如下:

基于已有的树切分测试数据:

如果存在任一字集是一棵树,则在该子集递归剪枝过程

计算将当前两个叶子节点和并后的误差

计算不合并的误差

如果合并会降低误差的话,就将叶节点合并

 

def isTree(obj):
    return (type(obj).__name__=='dict') #判断为字典类型返回true
#返回树的平均值
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0


#树的后剪枝
def prune(tree, testData):#待剪枝的树和剪枝所需的测试数据
    if shape(testData)[0] == 0: return getMean(tree)  # 确认数据集非空
    #假设发生过拟合,采用测试数据对树进行剪枝
    if (isTree(tree['right']) or isTree(tree['left'])): #左右子树非空
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    #剪枝后判断是否还是有子树
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        #判断是否merge
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        #如果合并后误差变小
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

代码测试:

from numpy import *
import regTrees
import matplotlib.pyplot as plt
myDat3 = regTrees.loadDataSet("exp2.txt")
myMat3 = mat(myDat3)
retct = regTrees.createTree(myMat3,ops = (0,1))
rtp = regTrees.prune(retct,myMat3)
print(retct)
print(rtp)
plt.plot(myMat3[:,0],myMat3[:,1],"ro")
plt.show()

5-4  用来测试某行书构建函数的分段线性数据

模型树的叶节点生成函数

#模型树
def linearSolve(dataSet):   #将数据集格式化为X Y
    m,n = shape(dataSet)
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0: #X Y用于简单线性回归,需要判断矩阵可逆
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

def modelLeaf(dataSet):#不需要切分时生成模型树叶节点
    ws,X,Y = linearSolve(dataSet)
    return ws #返回回归系数

def modelErr(dataSet):#用来计算误差找到最佳切分
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[  1.69855694e-03],

        [  1.19647739e+01]]), 'right': matrix([[ 3.46877936],

        [ 1.18521743]])}

该模型以0.285477为界创建了两个模型,而5-4数据实际在0.3处分段。生成的两个线性模型分别是y= 3.468+1.1852和y=0.0016985+11.96477x,与用于生成该模型的真实模型非常接近。该数据实际啊是有模型y=3.5+1.0x和y=0+12x再加上高斯噪声生成的。

2.R语言实现

#查看所有连续变量的相关性,所有分类变量的卡方值

> idx.num <- which(sapply(algae,is.numeric))

> idx.num

mxPH mnO2   Cl  NO3  NH4 oPO4  PO4 Chla   a1   a2   a3   a4   a5   a6   a7

   4    5    6    7    8    9   10   11   12   13   14   15   16   17   18

> correlation <- cor(algae$a1,algae[,idx.num],use = "pairwise.complete.obs")

> correlation

           mxPH      mnO2         Cl        NO3       NH4       oPO4

[1,] -0.2651354 0.2873732 -0.3711709 -0.2412111 -0.132656 -0.4173576

            PO4       Chla a1         a2         a3          a4         a5

[1,] -0.4864228 -0.2779866  1 -0.2937678 -0.1465666 -0.03795656 -0.2915492

             a6         a7

[1,] -0.2734283 -0.2129063

> correlation <- abs(correlation)

> correlation <- correlation[,order(correlation,decreasing = T)]

> correlation

        a1        PO4       oPO4         Cl         a2         a5       mnO2

1.00000000 0.48642276 0.41735761 0.37117086 0.29376781 0.29154923 0.28737317

      Chla         a6       mxPH        NO3         a7         a3        NH4

0.27798661 0.27342831 0.26513541 0.24121109 0.21290633 0.14656656 0.13265601

        a4

0.03795656

#查看分类变量的卡方值

> idx.factor <- which(sapply(algae,is.factor))

> idx.factor

season   size  speed

     1      2      3

> class(idx.factor)

[1] "integer"

> algae[,idx.factor]

    season   size  speed

1   winter  small medium

2   spring  small medium

3   autumn  small medium

4   spring  small medium

5   autumn  small medium

6   winter  small   high

7   summer  small   high

8   autumn  small   high

9   winter  small medium

10  winter  small   high

……

> t1 <- table(algae$season,algae$size)

> t1

        

         large medium small

  autumn    11     16    13

  spring    12     21    20

  summer    10     21    14

  winter    12     26    24

> chisq.test(t1)

 

Pearson's Chi-squared test

 

data:  t1

X-squared = 1.662, df = 6, p-value = 0.948

 







鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
1. R语言运行效率分析_小结(4)发布时间:2022-07-18
下一篇:
R语言绘图--交互图发布时间:2022-07-18
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap