• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

机器学习实战 决策树 python3实现 R语言实现(1)

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

二话不说,上代码:

1.python3实现

Python实现:ID3算法

1.决策树构造trees.py

from math import log
import operator
#简单的鉴定函数
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']

    return dataSet, labels
#计算给定的数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt
 #按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
 #选择最好的数据划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy
        if (infoGain >= bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
#字典对象存储了classList中每个类标签出现的频率,最后利用operator操作键值排序,返回出现次数最多的分类名称。
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
#创建树的函数代码
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]

#类别完全相同时停止继续划分
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

#遍历完所有特征值时返回出现次数最多的类别
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    subLabels = labels[:]
    del(subLabels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

#得到列表包含的所有属性
    for value in uniqueVals:

        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)

    return myTree                            
    
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel
#使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
    
def grabTree(filename):
    import pickle
    fr = open(filename,"rb")
    return pickle.load(fr)
    

 

2.绘制树型图treePlotter.py

import matplotlib.pyplot as plt
#使用文本注解绘制树节点

#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]

#这里书中有错误。用list转变为列表后才能用【】提取键值
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr =list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#在父子节点间填充文本信息    
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
#计算宽与高
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  
#标记子节点的属性
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
#减少y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#主函数
def createPlot(inTree):
    fig = plt.figure(1, facecolor= 'white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

3.测试:使用决策树执行分类。预测隐形眼镜类型

步骤:

1. 收集数据;《机器学习实战》提供的文本文件

2. 准备数据:解析tab键分隔的数据行

3. 分析数据:快速检查数据,确保正确的解析数据内容,使用createPlot()函数绘制最终的树形图。

4. 训练算法:使用createTree()函数

5. 测试算法:编写测试函数验证决策树可以正确分类给定的数据实例。

6. 使用算法:存储树的数据结构,一边下次使用时无需重新构造。


2.R语言实现

方法1:

library(party)  #用于实现决策树算法

library(sampling)  #用于实现数据分层随机抽样,构造训练集和测试集

data(iris)

str(iris)

'data.frame':150 obs. of  5 variables:

 $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...

 $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...

 $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...

 $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...

 $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

 

dim(iris)

[1] 150   5

 

sub_train = strata(iris,stratanames = "Species",size = rep(35,3),method = "srswor")

 

#strata:分层抽样在包sampling中,

Strata(data,stratanames = NULL,size,method=c(srswor,””srswr,poisson,systematic),pik,description = FALSE)

Data:待抽样数据

Stratanames:进行分层所依变量名

Size:各层中要抽出的观测样本数

Method:选择四种抽样方法,分别为无放回,有放回,泊松,系统抽样,默认srswor

Pik:设置各层中样本的抽样概率

Description:选择是否输出含有各层基本信息的结果。

data_train = iris[sub_train$ID_unit,]

data_train = iris[-sub_train$ID_unit,]

 

iris_tree = ctree(Species~.,data = data_train)  

#ctree:条件推理树是一种比较常用的基于树的分类算法,与传统决策树(rpart)不同之处在于条件推理树是选择分类变量时依据的是显著性测量的结果,而不是采用信息最大化法。Rpart采用的是基尼系数。

 

print(iris_tree)

 Conditional inference tree with 3 terminal nodes

 

Response:  Species

Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width

Number of observations:  45

 

1) Petal.Length <= 1.7; criterion = 1, statistic = 40.933

  2)*  weights = 15

1) Petal.Length > 1.7

  3) Petal.Width <= 1.6; criterion = 1, statistic = 19.182

    4)*  weights = 15

  3) Petal.Width > 1.6

    5)*  weights = 15

 

plot(iris_tree)

plot(iriis_tree,type = "simple")

方法2:RWeka实现C4.5算法该过程需要安装java。(或者不用?把变量定义为因子,还没试过。)

library(RWeka)

library(grid)

library(mvtnorm)

library(modeltools)

library(stats4)

library(strucchange)

library(zoo)

library(partykit)

library(rJava)

data(iris)

str(iris)

 

m1 <- J48(Species~.,data=iris)

m1

J48 pruned tree

------------------

 

Petal.Width <= 0.6: setosa (50.0)

Petal.Width > 0.6

|   Petal.Width <= 1.7

|   |   Petal.Length <= 4.9: versicolor (48.0/1.0)

|   |   Petal.Length > 4.9

|   |   |   Petal.Width <= 1.5: virginica (3.0)

|   |   |   Petal.Width > 1.5: versicolor (3.0/1.0)

|   Petal.Width > 1.7: virginica (46.0/1.0)

 

Number of Leaves  : 5

 

Size of the tree : 9

 

table(iris$Species,predict(m1))

       setosa versicolor virginica

  setosa         50          0         0

  versicolor      0         49         1

  virginica       0          2        48

plot(m1)






鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
R语言基础发布时间:2022-07-18
下一篇:
(转载)R语言ARIMA时间序利基于R语言的时间序列分析预测发布时间:2022-07-18
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap