8 min to read
决策树原理及其python实现
原理
简介
决策树学习:根据数据的属性采用树状结构建立决策模型。决策树模型常常用来解决分类和回归问题。常见的算法包括 CART (Classification And Regression Tree)、ID3、C4.5、随机森林 (Random Forest) 等。比如ID3算法使用信息增益作为不纯度;C4.5算法使用信息增益率作为不纯度【此方法避免了ID3算法中的归纳偏置问题,因为ID3算法会偏向于选择类别较多的属性(形成分支较多会导致信息增益大)】;CART算法使用基尼系数作为不纯度。
算法 | 支持模型 | 数结构 | 特征选择 | 连续值 | 缺失值 | 剪枝 |
---|---|---|---|---|---|---|
ID3 | 分类 | 多叉树 | 信息增益 | 不支持 | 不支持 | 不支持 |
C4.5 | 分类 | 多叉树 | 信息增益率 | 支持 | 支持 | 支持 |
CART | 分类、回归 | 二叉树 | 基尼系数 | 支持 | 支持 | 支持 |
算法
ID3算法
利用熵来衡量信息集的无序程度。
假设某节点S包含s个样本,共有m个类别,分别对应$C_i,i\in{1,2,...,m}$,一个类别中包含样本数为$s_i$,此时确定节点中任意一个样本的类别所需的信息量(信息熵)为:
$$ I(s_1,s_2,...,s_m)=-\sum_{i=1}^{m}p_i*log(p_i) $$
其中$p_i$表示任一样本属于类别$C_i$的概率,即$p_i=\frac{s_i}{s}$。
假设特征A有v个不同取值,${a_1,a_2,...,a_v}$,那么利用特征A可以将该节点样本划分为v个子集。$S_j$包含了集合S中特征A取$a_j$值的样本集合,对应的样本数为$s_j$。假设$S_{ij}$为子集$S_j$属于类别$C_i$的样本集合,对应的样本数目为$s_{ij}$,那么在子集$S_j$中确定任一样本类别所需的信息熵为:
$$ I(s_{1j},s_{2j},...,s_{mj})=-\sum_{i=1}^{m}p_{ij}*log(p_{ij}) $$
单独在各子集进行样本类别确定所需信息熵的的加权平均:
$$ E(A)=\sum_{j=i}^v\frac{s_{1j}+s_{2j}+...+s_{mj}}{s}*I(s_{1j},s_{2j},...,s_{mj}) $$
这样利用特征A对当前节点的样本进行划分子集的信息增益为:
$$ Gain(A)=I(s_1,s_2,...,s_m)-E(A) $$
C4.5算法
ID3算法倾向于选择水平数量较多的变量,可能导致训练得到一个庞大且深度浅的树,C4.5算法对此进行了改进。
$$ GainRatio(A)=\frac{Gain(A)}{SplitInfo(A)} \ SplitInfo(A)=-\sum_{i=1}^{v}\frac{s_i}{s}log_2(\frac{s}{s_i}) \ SplitInfo(A)反映的是按照特征A对样本进行划分的广度和均匀度。 $$
CART分类树算法
使用基尼系数来代替信息增益比,基尼系数代表了模型的不纯度,基尼系数越小,则不纯度越低,特征越好。
假设有K个类别,第k个类别的概率为$p_k$,则基尼系数的表达式为:
$$ Gini(p)=\sum_{k=i}^Kp_k*(1-p_k)=1-=\sum_{k=i}^Kp_k^2 $$
CART回归树算法
实际生活中很多问题都是非线性的,不可能使用全局线性模型来拟合任何数据。一种可行的方法是将数据集切分成很多份易建模的数据,然后利用线性回归技术来建模。
对于回归模型,我们使用了常见的和方差的度量特征的各个划分点的优劣,CART回归树的度量目标是,对于任意划分特征A,对应的任意划分点s两边划分成的数据集D1和D2,求出使D1和D2各自集合的均方差最小,同时D1和D2的均方差之和最小所对应的特征和特征值划分点。表达式为:
$$ min_{A,s}[min_{c_1}\sum_{x_i\in D_1(A,s)}(y_i-c_1)^2+min_{c_2}\sum_{x_i\in D_2(A,s)}(y_i-c_2)^2] \ 其中,c_1为D1数据集的样本输入均值,c_2为D2数据集的样本输出均值。 $$
优缺点
优点:
- 分类规则清晰,结果容易理解;
- 计算量相对较小,实现速度快;
- 同时可以处理分类变量和数值变量(但是可能决策树对连续变量的划分并不合理,所以可以提前先离散化);
- 另外决策树不需要做变量筛选,它会自动筛选,适合处理高维度数据;
- 可以处理异常值、缺失值。
缺点:
- 不是很稳点,数据变化一点,树就会发生变化;
- 没有考虑变量之间相关性,每次筛选都只考虑一个变量(因此不需要归一化);
- 只能线性分割数据;
- 容易出现过拟合;
- 贪婪算法(可能找不到最好的树);
- 如果某些特征的样本比例过大,生成决策树容易偏向于这些特征。
树枝修剪
让树的完全长成,但这会出现过拟合问题。为了抑制这种情况的发生,需要给树的剪枝。
树的剪枝分为预剪枝和后剪枝,预剪枝,及早的停止树增长控制树的规模。
后剪枝在已生成过拟合决策树上进行剪枝,删除没有意义的组,可以得到简化版的剪枝决策树,包括REP(Reduced-Error Pruning),设定一定的误分类率,减掉对误分类率上升不超过阈值的多余枝;还有一种CCP,即给分裂准则加上惩罚项,此时树的层数越深,基尼系数的惩罚项会越大。
实战
# -*- coding: UTF-8 -*-
from math import log
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #Label计数
shannonEnt = 0.0 #经验熵(香农熵)
for key in labelCounts: #计算香农熵
prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) #利用公式计算
return shannonEnt #返回经验熵(香农熵)
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], #数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况'] #特征标签
return dataSet, labels #返回数据集和分类属性
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
def majorityCnt(classList):
classCount = {}
for vote in classList: #统计classList中每个元素出现的次数
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序
return sortedClassCount[0][0] #返回classList中出现次数最多的元素
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] #取分类标签(是否放贷:yes or no)
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
def classify(inputTree, featLabels, testVec):
firstStr = next(iter(inputTree)) #获取决策树结点
secondDict = inputTree[firstStr] #下一个字典
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
testVec = [0,1] #测试数据
result = classify(myTree, featLabels, testVec)
if result == 'yes':
print('放贷')
if result == 'no':
print('不放贷')
参考文献
[1]. 算法模型---决策树
[2]. 数据挖掘(6):决策树分类算法
Comments