大数据分析 - 决策树

  • 简述

    决策树是一种用于分类或回归等监督学习问题的算法。决策树或分类树是其中每个内部(非叶)节点都用输入特征标记的树。来自标记有特征的节点的弧被标记为特征的每个可能值。树的每个叶子都标有一个类或类的概率分布。
    可以通过基于属性值测试将源集拆分为子集来“学习”树。这个过程以递归方式在每个派生的子集上重复,称为recursive partitioning. 当节点处的子集具有目标变量的所有相同值时,或者当拆分不再为预测增加值时,递归完成。这种自上而下归纳决策树的过程是贪心算法的一个例子,也是学习决策树最常用的策略。
    数据挖掘中使用的决策树有两种主要类型 -
    • Classification tree− 当响应是一个名义变量时,例如电子邮件是否为垃圾邮件。
    • Regression tree− 当预测结果可以被认为是一个实数时(例如工人的薪水)。
    决策树是一种简单的方法,因此存在一些问题。其中一个问题是决策树产生的结果模型的高方差。为了缓解这个问题,开发了决策树的集成方法。目前广泛使用两组集成方法 -
    • Bagging decision trees− 这些树用于构建多个决策树,方法是通过重复对训练数据进行替换重采样,并对树进行投票以获得共识预测。这种算法被称为随机森林。
    • Boosting decision trees− 梯度提升结合了弱学习器;在这种情况下,决策树以迭代的方式变成一个单一的强学习器。它为数据拟合一棵弱树,并迭代地拟合弱学习器,以纠正先前模型的错误。
    
    # Install the party package
    # install.packages('party') 
    library(party) 
    library(ggplot2)  
    head(diamonds) 
    # We will predict the cut of diamonds using the features available in the 
    diamonds dataset. 
    ct = ctree(cut ~ ., data = diamonds) 
    # plot(ct, main="Conditional Inference Tree") 
    # Example output 
    # Response:  cut  
    # Inputs:  carat, color, clarity, depth, table, price, x, y, z  
    # Number of observations:  53940  
    #  
    # 1) table <= 57; criterion = 1, statistic = 10131.878 
    #   2) depth <= 63; criterion = 1, statistic = 8377.279 
    #     3) table <= 56.4; criterion = 1, statistic = 226.423 
    #       4) z <= 2.64; criterion = 1, statistic = 70.393 
    #         5) clarity <= VS1; criterion = 0.989, statistic = 10.48 
    #           6) color <= E; criterion = 0.997, statistic = 12.829 
    #             7)*  weights = 82  
    #           6) color > E  
    #Table of prediction errors 
    table(predict(ct), diamonds$cut) 
    #            Fair  Good Very Good Premium Ideal 
    # Fair       1388   171        17       0    14 
    # Good        102  2912       499      26    27 
    # Very Good    54   998      3334     249   355 
    # Premium      44   711      5054   11915  1167 
    # Ideal        22   114      3178    1601 19988 
    # Estimated class probabilities 
    probs = predict(ct, newdata = diamonds, type = "prob") 
    probs = do.call(rbind, probs) 
    head(probs)