决策树(ID3)算法的实现和运用

算法,机器学习 2016-01-25

  决策树简单来说是一种可用于问题判定的树形结构。如下图,就是一个决策树: 1.png

  椭圆表示终止模块,矩形表示判断模块,箭头称为分支。决策树虽然看起来简单,但在数据挖掘,人工智能等领域有着广泛应用。这里介绍的决策树称为ID3,它无法直接用于处理数值型数据。

  算法分析   优点:复杂度不高,对中间值缺失不敏感,可处理不相关数据   缺点:会产生过度匹配问题   适用数据类型:标称型

  要构造决策树,很重要的一步就是划分数据集。为了找到决定性特征,必须对每一种特征进行评估,并对原始数据集进行划分,划分后的子集分布于第一个决策点的所有分支上,如果某分支下的数据属于同一类型,则无需进一步划分,否则重复划分直至相同。那么,划分数据集的原则是什么呢? 将无序数据变得有序。这里涉及到信息熵的概念,详细介绍可以参考信息熵,这里我们只需要知道计算公式:H(x) = E[I(xi)] = E[ log(2,1/p(xi)) ] = -∑p(xi)log(2,p(xi)) (i=1,2,..n)

  下面给出具体实现:

  

from math import log
import operator

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    #改变离散值
    return dataSet, labels

def calcShannonEnt(dataSet):        #计算数据集的信息熵
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: #为所有可能的分类创建字典
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #以2为底求对数
    return shannonEnt

def splitDataSet(dataSet, axis, value):     #划分数据集
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #抽取用于分割
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):   #选择最好的划分方式
    numFeatures = len(dataSet[0]) - 1      #最后一列作为标签
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #遍历所有特征
        featList = [example[i] for example in dataSet]#创建唯一分类标签列表
        uniqueVals = set(featList)       #获取一组独特的值
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     #计算每种方式的信息熵
        if (infoGain > bestInfoGain):       #比较
            bestInfoGain = infoGain         #如果找到更好的,替换之
            bestFeature = i
    return bestFeature                      #返回一个整数

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet,labels):      #用递归的方法创建决策树
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList): 
        return classList[0]          #类别完全相同则停止划分
    if len(dataSet[0]) == 1:         #当数据集中没有其它特征时停止划分
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #复制所有标签, 以便树不要弄乱标签
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree                            

def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

下面我们用隐形眼镜数据集进行测试,该数据集包含很多患者的眼部状况及医生推荐的眼镜类型,我们可以用决策树对结果进行预测。数据集下载 新建一个test.py文件:

   #encoding:utf-8

import trees
fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=trees.createTree(lenses,lensesLabels)
print(lensesTree) 

运行后输出:

  3.png

可以依次构建出了一个决策树。

参考:维基百科:ID3


本文由 Tony 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

如果对您有用,您的支持将鼓励我继续创作!