请选择 进入手机版 | 继续访问电脑版
  • 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Spark机器学习库指南[Spark 1.3.1版]——决策树(decision trees)

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

下面是章节决策树的目录(其他内容参见全文目录)

  • 基础算法
    • 节点不纯度和信息增益(Node impurity and information gain)
    • 分裂候选集(Split candidates)
    • 停止规则(Stopping rule)
  • 使用建议
    • 问题规格参数(Problem specification parameters)
    • 停止条件(Stopping criteria)
    • 可调参数(Tunable parameters)
    • 缓存和检查点(Caching and checkpointing)
  • 扩展性
  • 示例
    • 分类(Classification)
    • 回归(Regression)

决策树以及决策树的集成是很受欢迎的机器学习算法,经常用于解决分类和回归任务。决策树因为下列原因被广泛使用:

  1. 模型容易解释。决策树模型是可阅读的,从根到叶节点,每一步都可以看出决策依据。
  2. 能够处理类别型特征(跟连续特征相对)。
  3. 能解决多分类问题
  4. 不需要特征尺度变换(特征值只参入比较而不参入数值计算)
  5. are able to capture non-linearities and feature interactions[还没想好这个该怎么翻译啊]

像随机森林和迭代决策树这样的集成树算法在分类和回归任务上也位居前列。

MLlib支持使用决策树做二分类、多分类和回归,既能使用连续特征又能使用类别特征。具体实现中按行对数据分区,允许对百万级的实例进行分布式训练。

集成树(随机森林和梯度迭代树)将会在下一节中介绍。

基础算法

决策树是一个贪心算法,它递归地对特征空间做分裂处理(不一定是二分,单属性多个值的时候是多分)。决策树中每个底层(叶节点)分区对应一个相同的标签。每个分区都是通过从可能的分裂中贪心地选择一个最好的分裂,从而最大化当前节点的信息增益。也就是说,在每个树节点上选择的分裂是从集合argmaxsIG(D,s)中选出来的,其中IG(D, s)是在数据集D上做s切分的信息增益。

节点不纯度和信息增益(Node impurity and information gain)

节点不纯度是用来衡量通过某节点属性对样本做分裂之后每个部分的标签的一致性的,也就是看每部分中标签是不是属于同一个类型(labe)。MLlib当前的实现为分类提供了两种不纯度衡量方法(GINI不纯度和熵),为回归提供了一种不纯度衡量方法(方差)。

Impurity Task Formula Description
Gini 分类 fi是某个分区内第i个标签的频率,C是该分区中的类别总数数。GINI不纯度度量的是类型被分错的可能性(类型的概率乘以分错的概率
分类 i是某个分区内第i个标签的频率,C是该分区中的类别总数数
方差 分类 yi是某个实例的标签,N是实例的总数, μ是所有实例的均值:

基于信息增益的方法:信息增益是父节点的不纯度和两个子节点的不纯度之差。假设一个分裂s将数据集D(大小为N)分为两个数据集Dleft(大小为Nleft)和Dright(大小为Nright),那么信息增益可以表示为:

其中Impurity(D)、Impurity(Dleft)、Impurity(Dright)是相应系统的不纯度,见上文表格中的三种不纯度公式。注意这里的信息增益跟信息论的专门概率信息增益是不同的,它是父子节点的不纯度之差,而不单指熵差。

注:

使用信息增益的决策树算法实现有ID3、C4.5:ID3使用的是信息增益,它在分裂时可能倾向于属性值多的节点;C4.5是ID3的改进版,它使用的是信息增益率,另外还基于信息增益对连续型特征做了离散化处理。

使用GINI不纯度的决策树算法实现有CART。

 

分裂候选集(Split candidates)

连续特征(Continuous features)

在单机版实现中,数据集一般较小,对每个连续特征分裂出的候选集通常是去重之后的所有特征值。为了更快的进行树计算,某些实现中会对特征值做排序并使用有序的去重值集作为候选集和。

但在大规模分布式数据集上,特征值排序代价是非常昂贵的。决策树在大数据上的实现,一般会计算出一个近似的分裂候选集,具体做法是在抽样的部分数据集上做分位点(quantile)计算[分位点是将有序数据集分成N等分的分界点,例如:2分位点是中值或者说中位数],这个其实就是连续特征离散化。有序分裂产生了“区段”(bins),区段的最大数量可以通过maxBins参数设置。

注意:区段的数量不可能超过当前的实例数量。如果不满足这个条件,算法会自动调整区段参数值(maxBins的最大值是32)。

 

类别型特征(Categorical features)

对于类别型特征,如果该特征有M个可能的值(类别),那么我们可以得到高达2M-1-1个候选分裂集和。在二分类(0/1)和回归中,我们可以按照平均标签对类别特征值排序(参见《Elements of Statistical Machine Learning 》章节9.2.4),进而将分裂候选集减少到M-1。例如,在一个二分类问题中,一个类别型特征有三种类型A、B和C,对应的标签(二分类中label 1)分别占比0.2、0.6和0.4,那么该类别型特征值可以排序为A、C、B。两种分裂候选集是<A|C,B>和<A,C|B>,其中”|”表示分裂。

In multiclass classification, all 2M11 possible splits are used whenever possible. When 2M11 is greater than the maxBins parameter, we use a (heuristic) method similar to the method used for binary classification and regression. The M categorical feature values are ordered by impurity, and the resulting M1 split candidates are considered.(对这一段的理解不足,后续再翻译)

停止规则(Stopping rule)

递归树的构建在节点满足下列条件时停止:

  1. 节点深度等于maxDepth这个训练参数。
  2. 没有分裂候选集能产生比minInfoGain更大的信息增益。
  3. 没有分裂候选集产生的子节点都至少对应minInstancesPerNode个训练样本。

使用建议(Usage tips)

为了更好地使用决策树,在下文中我们会讨论各种参数的用法。下面列举的参数大致按重要程度排序。新手应该主要关注“问题规格参数”这一节以及“最大深度”这个参数。

问题规格参数(Problem specification parameters)

问题规格参数描述了要解决的问题和数据集。这些参数只需要指定即可不需要调优。

  • algo: Classification 或者 Regression
  • numClasses: 分类的类型数量Number of classes (只用于Classification)
  • categoricalFeaturesInfo: 指明哪些特征是类别型的以及每个类别型特征对应值(类别)的数量。通过map来指定,map的key是特征索引,value是特征值数量。不在这个map中的特征默认是连续型的。
    • 例如:Map(0 -> 2, 4->10)表示特征0有两个特征值(0和1),特征4有10个特征值{0,1,2,3,…,9}。注意特征索引是从0开始的,0和4表示第1和第5个特征。
    • 注意可以不指定参数categoricalFeaturesInfo。算法这个时候仍然会正常执行。但是类别型特征显示说明的话应该会训练出更好的模型。

停止条件(Stopping criteria)

这些参数决定了何时停止构建树(添加新节点)。当调整这些参数的时候,要谨慎使用测试集做校验防止过拟合。

  • maxDepth: 树的最大深度。越深的树表达能力越强(潜在允许更高的准确率),但是训练的代价也越大并更容易过拟合。
  • minInstancesPerNode: 如果一个节点需要分裂的话,它的每个子节点必须至少有minInstancesPerNode个训练样本。这个通常在随机森林中用到,因为随机森林比独立树有更大的训练深度。
  • minInfoGain: 如果一个节点需要分裂的话,必须最少有minInfoGain的信息增益。

可调参数(Tunable parameters)

这些参数可用于调优。调优时要在测试集上小心测试以免过拟合。

  • maxBins: 连续特征离散化时用到的最大区段(bins)数。
    • 增加maxBins的值,可以让算法考察跟多的分裂候选集,从而做耕细粒度的分裂。但是,这会增加计算量和通信开销。
    • 对于类别型特征,maxBins参数必须至少是特征值(类的数量M)的数量。
  • maxMemoryInMB:  存储统计信息的内存大小。
    • 默认值保守地设置为256MB,这个大小在绝大多数应用场景下够用。增加maxMemoryInMB(如果增加的内存可用的话)允许更少的数据遍历,可以提升训练速度。但是,这也可能降低汇报,因为当maxMemoryInMB增长时,每次体的带的通信开销也会成比例增长。
    • 实现上的细节:为了处理更好,决策树算法会收集一组需要分裂节点的统计信息(而不是一次一个节点)。一个组中能够处理的节点数量取决于内存需求(不同的特征差异大)。MaxMemoryInMB参数指定了每个worker上用于统计的内存限制(单位是MB)。
  • subsamplingRate: 用于学习决策树的训练样本比例。这个参数跟树的集成(随机森林,梯度提升树)最相关,用于从原始数据中抽取子样本。对于单个决策树,这个参数用途不大,因为训练样本的数量通常不是主要约束。
  • impurity: 用于选择候选分裂的不纯度度量标准。这个参数需要跟algo参数匹配。它的取值在上文表格中有讨论。

缓存和检查点(Caching and checkpointing)

MLlib 1.2添加了几个特性用来扩展到更大的树以及数的集成。当maxDepth设置得比较大,开启节点ID缓存和检查点就比较有用了。这些参数在numThrees设置得比较大的随机森林算法中也比较有用。

  • useNodeIdCache:当这个参数设置为true,算法会避免在每次迭代中将当前模型传给spark执行器(excutors)。
    • 这对深度大的树(加速计算)和大的随机森林(减少每次迭代的通信开销)比较有用。
    • 实现上的细节:默认情况下,算法向执行器传达当前模型的信息以便执行器匹配训练样本和树节点。如果这个设置开启,模型信息直接缓存而不需要传送。

节点ID缓存产生了一个RDD序列(每次迭代1个)。这个长序列会导致性能问题,但是为RDD设置检查点(checkpoitng)可以缓解这个问题。注意检查点只在useNodeIdCache开启时可用。

  • checkpointDir: 设置检查点的保存目录
  • checkpointInterval: 设置检查点的频率。过高会导致大量集群写操作。过低的话,如果执行器失败,RDD需要重新计算。

扩展性(Scaling)

在计算方面,扩展能力跟训练样本数、特征数量、maxBins参数有着近似线性的关系。在通信方面,扩展能力跟特征数量以及maxBins有着近似线性的关系。

The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input.

 

实现的算法既可以读取稀疏数据又可以读取密集型数据。但是,目前没有为稀疏输入做优化。

示例

分类(Classification)

下面的例子说明了怎样导入LIBSVM 数据文件,解析成RDD[LabeledPoint],然后使用决策树进行分类。GINI不纯度作为不纯度衡量标准并且树的最大深度设置为5。最后计算了测试错误率从而评估算法的准确性。

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = DecisionTreeModel.load(sc, "myModelPath")

 

回归(Regression)

下面的例子说明了如何导入LIBSVM 数据文件,解析为RDD[LabeledPoint],然后使用决策树执行回归。方差作为不存度衡量标准,树最大深度是5。最后计算了均方误差用来评估拟合度。

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = DecisionTreeModel.load(sc, "myModelPath")

鲜花

握手

雷人

路过

鸡蛋
专题导读
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap