前言

如果你以前没有接触过决策树,也不需要担心,它的概念非常简单。即使不知道它也可以通过简单的图形了解其中的工作原理,下图的流程图就是一个决策树,长方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),表示这已经得出结论,可以终止运行。从判断模块引出的左右箭头称为分支(branch),它可以到达另一个判断模块或者终止模块。该流程图构造了一个假想的邮件分类系统,它首先检测发生邮件域名地址。如果地址为 myEmployer.com ,则将其放在分类 “无聊时需要阅读的邮件”,其他同理分类。

image-20230709002848184

K-近邻算法已经可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内在含义决策树的主要优势就是在于数据形式非常容易理解。决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。

前排提醒:在接下来的代码示例中,有看不懂的函数,可以尝试在下面的 函数相关说明 处查看

决策树的构造

决策树的优缺点

  • 优点:计算复杂度不高,输出结果容易理解,对中间值的缺失并不敏感,可以处理不相关特征数据。
  • 缺点:可能会产生过度匹配的问题
  • 适用数据类型:数值型和标称型

首先,我们讨论数学上如何使用信息论划分数据集,然后编写代码将理论应用到具体的数据集上,最后编写代码构建决策树。

在构造决策树时,我们需要解决的第一个问题:当前数据集上哪个特征在划分数据分类时起决定性作用。 为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则不需要再分割了。如果某分支下的数据不属于同一类型,则需要重复划分数据子集,直到所有相同类型的数据被划分为各自的子集中。

创建分支的伪代码如下所示:

1
2
3
4
5
6
7
8
9
检测数据集中的每个子项是否属于同一分类
If true return 类标签
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数自己(递归)
return 分支节点

这个伪代码函数是一个递归函数,后续我们会使用python代码来实现这段伪代码。一些决策树采用二分法划分数据,本文并不采用这种方法。本文将使用 ID3 算法划分数据集。每次划分数据集时,我们只选取一个特征属性,如果训练集中存在 20 个特征,第一次我们选择哪个特征作为划分的参考属性呢?

一些常见的决策树算法:

  1. ID3(Iterative Dichotomiser 3):ID3 是最早的决策树算法之一,它使用信息增益来选择最优的特征进行分裂。然而,ID3 倾向于选择具有更多取值的特征,因此在实践中往往使用其他算法。
  2. C4.5:C4.5 是 ID3 的改进版本,它使用信息增益比来选择最优的特征。相对于 ID3,C4.5 能够处理连续特征和缺失数据,并且可以生成具有更好泛化能力的决策树。
  3. CART(Classification and Regression Trees):CART 是一种常用的决策树算法,可以用于分类和回归问题。CART 使用基尼系数(Gini Index)来选择最优的特征进行分裂,它生成的决策树是二叉树结构。
  4. CHAID(Chi-squared Automatic Interaction Detection):CHAID 是一种基于卡方检验的决策树算法,适用于分类问题。它可以处理离散和连续特征,并且能够检测特征之间的交互作用。
  5. Random Forest(随机森林):随机森林是一种集成学习方法,基于多个决策树进行预测。每个决策树都是通过随机选择样本和特征进行训练的,最后的预测结果由多个决策树的投票或平均值得出。
  6. Gradient Boosting Trees(梯度提升树):梯度提升树也是一种集成学习方法,通过迭代地训练决策树来提高预测性能。每个决策树都是在前一棵树的残差基础上进行训练的,最终的预测结果是多个决策树的加权和。

信息增益

划分数据集的大原则是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。

在划分数据集之前之后信息发生的变化称为:信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择

在评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德·艾尔伍德·香农

如果看不明白什么是信息增益和熵,也不需要着急——它们自诞生的那一天起,就注定令人费解。

59f433e711f04e0cae3e7eca648e8980

熵的定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符合 $x_i$ 的信息定义为:$\large l(x_i) = - log_2{p(x_i)}$ ,其中 $\large p(x_i)$ 是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

$\large H = - \sum^{n}_{i=1}p(x_i) \log_2{p(x_i)}$,其中 $n$ 是分类的数目。

关于这两个公式的理解可以参考如何理解信息熵

下面我们将学习如何使用python计算信息熵,创建名称为trees.py的文件,如下代码为计算给定数据集的熵。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from math import log
# 计算香农信息熵
def calcShannonEnt(dataSet):
# 获取数据的长度(或者说有多少条数据)
numEntries = len(dataSet)
# 创建一个空的字典
labelCounts = {}
# 遍历每条数据
for featVec in dataSet:
# 拿到数据的最后一项,即标签项
currentLabel = featVec[-1]
# 如果这个标签不在我们创建的字典中,我们就创建它
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
# 在字典中,计数标签+1
labelCounts[currentLabel] += 1
# 香农信息熵默认为 0
shannonEnt = 0.0
# 遍历特征在字典中(所有特征)
for key in labelCounts:
# 计算每个类型特征所占用比例(概率)
prob = float(labelCounts[key])/numEntries
# 带入香农信息熵公式
shannonEnt -= prob * log(prob,2)
# 返回计算的信息熵结果
return shannonEnt

现在我们使用如下数据来测试一下我们的方法:

1
2
3
4
5
dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n']]
print(calcShannonEnt(dataSet))

# 输出
0.9709505944546686

熵越高,则混合的数据也越多,我们可以在测试数据集中添加更多的分类,观察熵的变化,现在我们增加第三个名为z的分类,测试熵的变化:

1
2
3
4
5
dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n'],[0,0,'z']]
print(calcShannonEnt(dataSet))

# 输出
1.4591479170272448

得到熵之后,我们就可以按照获取最大信息增益的方法划分数据集,下个部分我们将具体学习如何划分数据集以及如何度量信息增益。

另一个度量集合无序程度的方法是基尼不纯度,简单来说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。本文不采用基尼不纯度方法,这里不做更多说明。

划分数据集

上个部分我们学习了如何度量数据集的无序程度,分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断是否正确地划分了数据集。我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式

现在我们使用如下代码按照特征来划分数据集,代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#划分数据集,传入数据集,特征在数据集的位置,要划分的特征
def splitDataSet(dataSet, axis, value):
# 创建空的列表返回划分好的结果
retDataSet = []
# 开始遍历数据集划分数据集
for featVec in dataSet:
# 判断特征值是不是要划分的标准
if featVec[axis] == value:
# 创建划分对应特征的集合
reducedFeatVec = featVec[:axis]
# 将划分特征数据后续复制到创建的新集合中
reducedFeatVec.extend(featVec[axis+1:])
# 将划分的集合添加到分类集合中
retDataSet.append(reducedFeatVec)
# 返回划分好的集合
return retDataSet

现在我们来测试上述代码,代码示例:

1
2
3
4
5
dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n'],[0,0,'z']]
print(splitDataSet(dataSet,0,1))

# 输出
[[0, 'y'], [0, 'y']]

接下来我们将会遍历整个数据集,循环计算香农熵和splitDataSet()函数(划分数据集),找到最好的特征划分方式。熵计算会告诉我们如何划分数据集是最好的数据组织方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 香农熵划分最佳数据集
def chooseBestFeatureToSplit(dataSet):
# 获取特征数量(总数)
numFeatures = len(dataSet[0]) - 1
# 获取当前没分类的信息熵
baseEntropy = calcShannonEnt(dataSet)
# 定义最好的信息熵值
bestInfoGain = 0.0
# 定义最好的分类特征
bestFeature = -1
# 遍历每个特征
for i in range(numFeatures):
# 获取当前特征的所有值
featList = [example[i] for example in dataSet]
# 去重,获取单一且不重复的当前特征枚举值
uniqueVals = set(featList)
# 定义新的信息熵
newEntropy = 0.0
# 开始遍历计算每个当前特征枚举值的信息熵,最后公式加和得到当前特征值分类后的信息熵
for value in uniqueVals:
# 根据当前特征去重的枚举值分割数据集
subDataSet = splitDataSet(dataSet, i, value)
# 获取当前枚举值的概率
prob = len(subDataSet)/float(len(dataSet))
# 根据概率*信息熵最终加和得到当前特征值的信息熵
newEntropy += prob * calcShannonEnt(subDataSet)
# 用基础的信息熵-分类后的信息熵,得到新的信息熵差异
infoGain = baseEntropy - newEntropy
# 如果这个差异是正数,即大于零,意味着按当前特征分类后的信息熵降低了,不混乱了
if (infoGain > bestInfoGain):
# 记录当前最好的信息熵
bestInfoGain = infoGain
# 记录获得最好信息熵的特征索引
bestFeature = i
# 返回得到的分类当前数据集最好的特征索引
return bestFeature

如果你实在觉得绕看不懂,可以单步调试,或者在关键的地方让它输出看看结果,多次尝试就明白了。

b9fae566feea6c184c43c992a07be226

递归构建决策树

目前我们已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则来处理数据集。

递归的结束条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。例如下图:

image-20230709002848184

第一个结束条件可以使得算法可以终止,我们甚至可以设置算法可以划分的最大分组数目。后续还会说明其他决策树算法,例如 C4.5 和 CART,这些算法在运行时并不总是在每次划分分组时都会消耗特征。由于特征数目并不是在每次划分数据时减少,因此这些算法在实际使用时候可能会引起一些问题。目前我们并不需要考虑这个问题,只需要在算法开始运行计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定改叶子节点的分类。

现在我们打开tree.py文件,在文件头部添加import operator,然后在文本中添加如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 标签投票分类(如果数据集已经处理了所有属性,但是类标签依然不是唯一的,我们需要决定如何定义该叶子节点)
def majorityCnt(classList):
# 创建一个空的标签字典
classCount={}
# 遍历数据集中的标签
for vote in classList:
# 如果标签不存在字典中,则创建标签,其初始值为 0
if vote not in classCount.keys(): classCount[vote] = 0
# 对应标签的值+1
classCount[vote] += 1
# 将字典标签按值排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# 返回数量最多的一个标签
return sortedClassCount[0][0]

现在我们来在文件中添加最后的递归相关的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 创建决策树,传入数据集和标签列表
def createTree(dataSet,labels):
# 获取所有数据集最后一列的数据(标签)
classList = [example[-1] for example in dataSet]
# 如果传入的数据集都是一个类别,就直接返回节点
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果已经遍历完所有特征,只剩下标签列,则返回样本中出现次数最多的类别作为叶子节点
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 划分最佳数据集
bestFeat = chooseBestFeatureToSplit(dataSet)
# 得到划分最佳数据集的标签
bestFeatLabel = labels[bestFeat]
# 创建一个字典,以最佳特征的标签为键,值为空字典,用于构建决策树
myTree = {bestFeatLabel:{}}
# 删除已选择的最佳特征的标签,以便在递归调用时传递给下一层
del(labels[bestFeat])
# 获取数据集中最佳特征的所有取值
featValues = [example[bestFeat] for example in dataSet]
# 获取最佳特征的唯一取值集合
uniqueVals = set(featValues)
# 递归遍历
for value in uniqueVals:
# 创建一个副本,以便在递归调用时传递给下一层
subLabels = labels[:]
# 递归调用createTree函数,传递划分后的子数据集和剩余特征的标签,将返回的子树作为当前节点的值
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
# 返回构建好的决策树
return myTree

内容很多,很抽象是吧🤣,我也觉得很抽象,理解上面的整个代码运行过程,我们来举一个实例来理解这段代码,现在我们有如下的数据,我们需要对它们进行分类。

是否有脚 是否有鳞片 是否有鳃 是否有尾巴 【特征值】
1 1 0 0 非鱼类
1 1 1 0 鱼类
0 1 0 1 鱼类
1 0 0 0 非鱼类
0 1 0 0 鱼类

我们将上述数据转换成运行的python代码如下所示:

1
2
3
4
5
6
7
8
dataSet = data = [
[1, 1, 0, 0, '非鱼类'],
[1, 1, 1, 0, '鱼类'],
[0, 1, 0, 1, '鱼类'],
[1, 0, 0, 0, '非鱼类'],
[0, 1, 0, 0, '鱼类']
]
createTree(dataSet,['是否有脚','是否有鳞片','是否有鳃','是否有尾巴'])

很明显的看出,上面的递归代码中的labels就是表格的表头,它是用来给每一列数据进行标注的,或者说是用来解释数据的,对于计算机来说这一列并没有参考性,但是对于我们来说是有参考意义的。

现在我们运行这段代码,运行到这里:

1
2
# 获取所有数据集最后一列的数据(标签)
classList = [example[-1] for example in dataSet]

我们得到classList = ['非鱼类', '鱼类', '鱼类', '非鱼类', '鱼类'],也就说明它提取了我们的所有数据的特征。

现在运行到下面的代码部分:

1
2
3
4
5
6
# 如果传入的数据集都是一个类别,就直接返回节点
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果已经遍历完所有特征,只剩下标签列,则返回样本中出现次数最多的类别作为叶子节点
if len(dataSet[0]) == 1:
return majorityCnt(classList)

这两个部分对应的处理就是我们前面说的递归的结束情况,代码第一个if部分判断,如果给定的数据集类别中,第一个类别的数量等于该数据集所有类别的数量,就说明它们都是一个类别的,已经不需要分类了(递归的结束条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类)。

第二个if部分判断的是,如果我们的整个数据集只有一列了,那就说明只剩下了最右侧的特征值列,说明已经把属性都分类完了,这个时候也不再需要分类了(如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定改叶子节点的分类)。

现在来继续往下执行代码:

1
2
# 划分最佳数据集
bestFeat = chooseBestFeatureToSplit(dataSet)

现在我们将我们的数据集进行第一次划分,形象点来说就是决策树第一次分叉,执行chooseBestFeatureToSplit()函数后,我们的输出结果:bestFeat = 0,它告诉了我们这个数据集的第一次划分最好的属性索引是 0,对应索引的是是否有脚

具体划分原理,参考前面的信息增益部分

现在我们知道了第一次应该按什么属性来划分,代码继续运行:

1
2
# 得到划分最佳数据集的标签
bestFeatLabel = labels[bestFeat]

通过这个,就得到了前面我说的最佳划分属性的标签,就是是否有脚。我们在知道第一次划分的属性后,接下来构建决策树的雏形:

1
2
# 创建一个字典,以最佳特征的标签为键,值为空字典,用于构建决策树
myTree = {bestFeatLabel:{}}

现在我们创建了一个变量myTree来存储决策树,其中它的类型是字典类型,存储了一个key = bestFeatLabel也就是key = '是否有脚'key,它对应的value是一个空的字典,也就是代码中的{},这行代码的其最终的结果:myTree = {'是否有脚': {}}

现在代码继续执行到如下位置:

1
2
# 删除已选择的最佳特征的标签,以便在递归调用时传递给下一层
del(labels[bestFeat])

它删除了我们标签中的第一次分叉属性,也就是由之前的['是否有脚', '是否有鳞片', '是否有鳃', '是否有尾巴']变成了['是否有鳞片', '是否有鳃', '是否有尾巴']

接下来代码继续执行到如下位置:

1
2
# 获取数据集中最佳特征的所有取值
featValues = [example[bestFeat] for example in dataSet]

这句代码右侧是列表推导式,这句代码运行结果是:[1, 1, 0, 1, 0],它提取所有第一个分叉最佳属性的所有值,因为我们接下来要根据值来继续划分数据集了。

具体列表推导式是什么参考下面的相关函数说明

代码继续执行:

1
2
# 获取最佳特征的唯一取值集合
uniqueVals = set(featValues)

这句代码执行结果就是去重,它的执行结果是:{0, 1},这样我们就得到了当前最佳属性的唯一取值集合,接下来就是“分叉”,第一个“叉”是按 0 来分的,第二个“叉”是按 1 来分的。

代码继续执行:

1
2
3
4
5
6
# 递归遍历
for value in uniqueVals:
# 创建一个副本,以便在递归调用时传递给下一层
subLabels = labels[:]
# 递归调用createTree函数,传递划分后的子数据集和剩余特征的标签,将返回的子树作为当前节点的值
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)

现在我们遍历第一个“叉”,即value = 0,我们先是完全拷贝了一份labelssubLabels,接下来,我们对于

第一个分叉,先做了一个划分,即splitDataSet(dataSet, bestFeat, value),它的运行结果是返回了:[[1, 0, 1, '鱼类'], [1, 0, 0, '鱼类']],也就是第一列所有value = 0的值划分的一组(去除了第一列的值),然后形成的这个新的分组就是[[1, 0, 1, '鱼类'], [1, 0, 0, '鱼类']],转换成表格如下所示:

是否有鳞片 是否有鳃 是否有尾巴 【特征值】
1 0 1 鱼类
1 0 0 鱼类

然后这组数据再次执行createTree(splitDataSet(dataSet, bestFeat, value),subLabels)进行分类构建,但是它很明显特征值都是一个类型的,即鱼类,所以它在执行到如下代码就返回了:

1
2
3
# 如果传入的数据集都是一个类别,就直接返回节点
if classList.count(classList[0]) == len(classList):
return classList[0]

然后接下来回到递归遍历的地方,此时value = 1,也就是右分叉再次执行这个循环,知道满足前面说的两个结束条件,然后才会结束,最后返回分类好的决策树。

**最终的运行结果是:{'是否有脚': {0: '鱼类', 1: {'是否有鳃': {0: '非鱼类', 1: '鱼类'}}}}**。这个结果很难直观的来理解是吧,现在将其可视化,就是如下图所示:

Snipaste_2023-07-13_00-49-24

好了,现在你应该已经了解了如何构造决策树了,对于晦涩难懂的输出,图更加帮助我们理解分类器的内在逻辑,接下来我们来绘制决策树,来可视化我们的决策树。

使用Graphviz绘制树形图

需要说明的是 Python 本身并不具备绘制图形/图表的能力,我们需要通过拓展包来实现相关功能,在 Python 中有一些常用的包提供绘图相关操作:

  1. Matplotlib:Matplotlib是一个功能强大的绘图库,可以用于绘制各种类型的图表,包括树形图。
  2. NetworkX:NetworkX是一个专门用于创建、操作和研究复杂网络的Python库。它提供了一些功能强大的函数和算法,用于绘制树形图、图形布局和节点样式设置。
  3. Graphviz:Graphviz是一个开源的图形可视化工具包,可以用于绘制各种类型的图形,包括树形图。它使用DOT语言描述图形结构,并提供了Python接口供调用。
  4. anytree:anytree是一个轻量级的Python库,用于处理和操作树形数据结构。它提供了创建、遍历和操作树形结构的功能,并支持将树形结构可视化为文本、图形或其他格式。anytree提供了一些可选的渲染器,可以将树形结构绘制为图形。

从简单程度来说,使用Graphviz包是比较简单的,所以我采用该包进行树形图绘制演示。

如何下载安装该包,此处不再做演示,绘制图形代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import graphviz

# 创建有向图
dot = graphviz.Digraph()

# 设置节点和边的字体
dot.attr('node', fontname='SimHei')
dot.attr('edge', fontname='SimHei')

# 添加节点
dot.node('A', '是否有脚')
dot.node('B', '鱼类')
dot.node('C', '是否有鳃')
dot.node('D', '非鱼类')
dot.node('E', '鱼类')

# 添加边
dot.edge('A', 'B',label='0')
dot.edge('A', 'C',label='1')
dot.edge('C', 'D',label='0')
dot.edge('C', 'E',label='1')

# 渲染并保存图形
dot.render('tree', format='png', view=True)

其渲染结果如下所示:

Snipaste_2023-07-13_01-47-35

序列化决策树

构造决策树是很耗时的任务,如果面对的数据集很大,将会耗费更多的计算时间。然后如果我们使用创建好的决策树解决分类问题,将会大大节约时间。因此为了节省时间,最好是能够在每次执行分类时调用已经构造好的决策树。

为了解决这个问题,需要使用 Python 模块 pickle序列化对象,代码如下所示。序列化对象可以在磁盘上存储,并在我们需要的时候读取出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 序列化决策树
def storeTree(inputTree, filename):
import pickle
# 获取文件指针,打开文件
fw = open(filename, 'wb')
# 写入文件
pickle.dump(inputTree, fw)
# 关闭文件
fw.close()

# 加载序列化的决策树
def grabTree(filename):
import pickle
# 打开文件
fr = open(filename,'rb')
# 反序列化并返回对象
return pickle.load(fr)

这样我们在构建决策树的时候就可以序列化存储起来,然后需要的时候调用出来:

1
2
3
4
5
6
7
8
9
10
11
12
13
dataSet = data = [
[1, 1, 0, 0, '非鱼类'],
[1, 1, 1, 0, '鱼类'],
[0, 1, 0, 1, '鱼类'],
[1, 0, 0, 0, '非鱼类'],
[0, 1, 0, 0, '鱼类']
]
# 序列化决策树
storeTree(createTree(dataSet,['是否有脚','是否有鳞片','是否有鳃','是否有尾巴']),'Modeltree.txt')
# 读取
model = grabTree('Modeltree.txt')
# 输出
print(model)

代码的输出结果:{'是否有脚': {0: '鱼类', 1: {'是否有鳃': {0: '非鱼类', 1: '鱼类'}}}}

通过上面的代码,我们可以将分类器存储在磁盘上,不必每次都需要学习一下,这也是决策树的优点之一,而相对于上一篇说明的KNN(k-近邻算法)就无法持久化分类器。

使用决策树预测隐形眼镜类型

使用小数据集,我们就可以利用决策树学到很多知识:眼科医生是如何判断患者需要佩戴的镜片类型?一旦理解了决策树的工作原理,我们甚至也可以帮助人们判断需要佩戴的镜片类型。

关于隐形眼镜的数据集在文本的最后,相关数据部分,将数据记得保存在一个txt中。

现在我们在 Python 中调用如下代码:

1
2
3
4
5
6
7
8
# 读取数据集
fr = open('treedata.txt')
# 处理成需要的格式
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
# 对应的标签
lensesLabels=['age', 'prescript', 'astigmatic', 'tearRate']
# 生成决策树
print(createTree (lenses,lensesLabels))

输出结果:{'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'young': 'hard', 'presbyopic': 'no lenses', 'pre': 'no lenses'}}, 'myope': 'hard'}}, 'no': {'age': {'young': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'pre': 'soft'}}}}, 'reduced': 'no lenses'}}。可视化后如下所示:

Snipaste_2023-07-13_02-26-53

从上图我们也可以发现,医生最多需要四个问题就能确定患者需要佩戴哪种类型的隐形眼镜。

上图的决策树也非常好的匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配。为了减少过度匹配,我们可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少量信息,则可以删除该节点,将它并入其他叶子节点中。我们将会在后续讨论这个问题。

相关函数说明

extend()

extend() 是 Python 列表对象的一个方法,用于将一个可迭代对象中的元素逐个添加到列表中。它会修改原始列表,将可迭代对象中的元素追加到列表的末尾。代码示例:

1
2
3
4
5
6
7
my_list = [1, 2, 3]
another_list = [4, 5, 6]
my_list.extend(another_list)
print(my_list)

# 输出
[1, 2, 3, 4, 5, 6]

你会发现它和append()函数很像,具体不同看下面

append()

append() 是 Python 列表对象的一个方法,用于将一个元素添加到列表的末尾。它会修改原始列表,将元素追加到列表的最后一个位置。代码示例:

1
2
3
4
5
6
7
a = [1,2,3]
b = [4,5,6]
a.append(b)
print(a)

# 输出
[1, 2, 3, [4, 5, 6]]

列表推导式

在划分数据集中,featList = [example[i] for example in dataSet]这就是一个列表推导式,它用于提取数据集中每个样本的第 i 个特征的取值。代码示例:

1
2
3
4
5
6
7
8
9
dataSet = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
print([i[1] for i in dataSet])

# 输出
[2, 5, 8]

set()

set() 是一个Python内置函数,用于创建一个无序、不重复元素的集合。集合是一种可变的数据类型,它可以存储各种不同的元素,但不允许有重复的元素。代码示例:

1
2
3
4
5
numbers = [1, 2, 3, 3, 4, 5, 5]
print(set(numbers))

# 输出
{1, 2, 3, 4, 5}

相关数据

隐形眼镜数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
young	myope	no	reduced	no lenses
young myope no normal soft
young myope yes reduced no lenses
young myope yes normal hard
young hyper no reduced no lenses
young hyper no normal soft
young hyper yes reduced no lenses
young hyper yes normal hard
pre myope no reduced no lenses
pre myope no normal soft
pre myope yes reduced no lenses
pre myope yes normal hard
pre hyper no reduced no lenses
pre hyper no normal soft
pre hyper yes reduced no lenses
pre hyper yes normal no lenses
presbyopic myope no reduced no lenses
presbyopic myope no normal no lenses
presbyopic myope yes reduced no lenses
presbyopic myope yes normal hard
presbyopic hyper no reduced no lenses
presbyopic hyper no normal soft
presbyopic hyper yes reduced no lenses
presbyopic hyper yes normal no lenses

End

本文使用的算法是 ID3 ,它是一个好的算法但是并不完美。ID3 算法无法直接处理数值型数据,尽管我们可以量化的方法将数值型数据转换为标称型数值,但是如果存在太多特征划分,ID3 算法仍然面临其他问题。

后续我们将会学习另一个构造决策树的算法 CART,它使用基尼系数(Gini Index)来选择最优的特征进行分裂。