上一節實現了決策樹,但只是使用包含樹結構信息的嵌套字典來實現,其表示形式較難理解,顯然,繪制直觀的二叉樹圖是十分必要的。Python沒有提供自帶的繪制樹工具,需要自己編寫函數,結合Matplotlib庫創建自己的樹形圖。這一部分的代碼多而復雜,涉及二維坐標運算;書里的代碼雖然可用,但函數和各種變量非常多,感覺非常凌亂,同時大量使用遞歸,因此只能反復研究,反反復復用了一天多時間,才差不多搞懂,因此需要備注一下。
**一.繪制屬性圖**
這里使用Matplotlib的注解工具annotations實現決策樹繪制的各種細節,包括生成節點處的文本框、添加文本注釋、提供對文字著色等等。在畫一整顆樹之前,最好先掌握單個樹節點的繪制。一個簡單實例如下:
~~~
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015
@author: Herbert
"""
import matplotlib.pyplot as plt
nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")
def plotNode(nodeName, targetPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \
'axes fraction', xytext = targetPt, \
textcoords = 'axes fraction', va = \
"center", ha = "center", bbox = nodeType, \
arrowprops = line)
def createPlot():
fig = plt.figure(1, facecolor = 'white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon = False)
plotNode('nonLeafNode', (0.2, 0.1), (0.4, 0.8), nonLeafNodes)
plotNode('LeafNode', (0.8, 0.1), (0.6, 0.8), leafNodes)
plt.show()
createPlot()
~~~
輸出結果:

該實例中,`plotNode()`函數用于繪制箭頭和節點,該函數每調用一次,將繪制一個箭頭和一個節點。后面對于該函數有比較詳細的解釋。`createPlot()`函數創建了輸出圖像的對話框并對齊進行一些簡單的設置,同時調用了兩次`plotNode()`,生成一對節點和指向節點的箭頭。
**繪制整顆樹**
這部分的函數和變量較多,為方便日后擴展功能,需要給出必要的標注:
~~~
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 04 01:15:01 2015
@author: Herbert
"""
import matplotlib.pyplot as plt
# 部分代碼是對繪制圖形的一些定義,主要定義了文本框和剪頭的格式
nonLeafNodes = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodes = dict(boxstyle = "round4", fc = "0.8")
line = dict(arrowstyle = "<-")
# 使用遞歸計算樹的葉子節點數目
def getLeafNum(tree):
num = 0
firstKey = tree.keys()[0]
secondDict = tree[firstKey]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
num += getLeafNum(secondDict[key])
else:
num += 1
return num
# 同葉子節點計算函數,使用遞歸計算決策樹的深度
def getTreeDepth(tree):
maxDepth = 0
firstKey = tree.keys()[0]
secondDict = tree[firstKey]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
depth = getTreeDepth(secondDict[key]) + 1
else:
depth = 1
if depth > maxDepth:
maxDepth = depth
return maxDepth
# 在前面例子已實現的函數,用于注釋形式繪制節點和箭頭
def plotNode(nodeName, targetPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeName, xy = parentPt, xycoords = \
'axes fraction', xytext = targetPt, \
textcoords = 'axes fraction', va = \
"center", ha = "center", bbox = nodeType, \
arrowprops = line)
# 用于繪制剪頭線上的標注,涉及坐標計算,其實就是兩個點坐標的中心處添加標注
def insertText(targetPt, parentPt, info):
xCoord = (parentPt[0] - targetPt[0]) / 2.0 + targetPt[0]
yCoord = (parentPt[1] - targetPt[1]) / 2.0 + targetPt[1]
createPlot.ax1.text(xCoord, yCoord, info)
# 實現整個樹的繪制邏輯和坐標運算,使用的遞歸,重要的函數
# 其中兩個全局變量plotTree.xOff和plotTree.yOff
# 用于追蹤已繪制的節點位置,并放置下個節點的恰當位置
def plotTree(tree, parentPt, info):
# 分別調用兩個函數算出樹的葉子節點數目和樹的深度
leafNum = getLeafNum(tree)
treeDepth = getTreeDepth(tree)
firstKey = tree.keys()[0] # the text label for this node
firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,\
plotTree.yOff)
insertText(firstPt, parentPt, info)
plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
secondDict = tree[firstKey]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], firstPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \
firstPt, leafNodes)
insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 以下函數執行真正的繪圖操作,plotTree()函數只是樹的一些邏輯和坐標運算
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops)
# 全局變量plotTree.totalW和plotTree.totalD
# 用于存儲樹的寬度和樹的深度
plotTree.totalW = float(getLeafNum(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), ' ')
plt.show()
# 一個小的測試集
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0: 'no', 1:{'flippers':{0:'no', 1:'yes'}}}},\
{'no surfacing':{0: 'no', 1:{'flippers':{0:{'head':{0:'no', \
1:'yes'}}, 1:'no'}}}}]
return listOfTrees[i]
createPlot(retrieveTree(1)) # 調用測試集中一棵樹進行繪制
~~~
`retrieveTree()`函數中包含兩顆獨立的樹,分別輸入參數即可返回樹的參數`tree`,最后執行`createPlot(tree)`即得到畫圖的結果,如下所示:


書中關于遞歸計算樹的葉子節點和深度這部分十分簡單,在編寫繪制屬性圖的函數時,難度在于這本書中一些繪圖坐標的取值以及在計算節點坐標所作的處理,書中對于這部分的解釋比較散亂。博客:[http://www.cnblogs.com/fantasy01/p/4595902.html](http://www.cnblogs.com/fantasy01/p/4595902.html)?給出了十分詳盡的解釋,包括坐標的求解和公式的分析,以下只摘取一部分作為了解:
這里說一下具體繪制的時候是利用自定義,如下圖:
這里繪圖,作者選取了一個很聰明的方式,并不會因為樹的節點的增減和深度的增減而導致繪制出來的圖形出現問題,當然不能太密集。這里利用整 棵樹的葉子節點數作為份數將整個x軸的長度進行平均切分,利用樹的深度作為份數將y軸長度作平均切分,并利用plotTree.xOff作為最近繪制的一 個葉子節點的x坐標,當再一次繪制葉子節點坐標的時候才會plotTree.xOff才會發生改變;用plotTree.yOff作為當前繪制的深 度,plotTree.yOff是在每遞歸一層就會減一份(上邊所說的按份平均切分),其他時候是利用這兩個坐標點去計算非葉子節點,這兩個參數其實就可 以確定一個點坐標,這個坐標確定的時候就是繪制節點的時候
**`plotTree`函數的整體步驟分為以下三步:**
1. 繪制自身
2. 若當前子節點不是葉子節點,遞歸
3. 若當子節點為葉子節點,繪制該節點
以下是`plotTree`和`createPlot`函數的詳細解析,因此把兩個函數的代碼單獨拿出來了:
~~~
# 實現整個樹的繪制邏輯和坐標運算,使用的遞歸,重要的函數
# 其中兩個全局變量plotTree.xOff和plotTree.yOff
# 用于追蹤已繪制的節點位置,并放置下個節點的恰當位置
def plotTree(tree, parentPt, info):
# 分別調用兩個函數算出樹的葉子節點數目和樹的深度
leafNum = getLeafNum(tree)
treeDepth = getTreeDepth(tree)
firstKey = tree.keys()[0] # the text label for this node
firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/plotTree.totalW,\
plotTree.yOff)
insertText(firstPt, parentPt, info)
plotNode(firstKey, firstPt, parentPt, nonLeafNodes)
secondDict = tree[firstKey]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], firstPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \
firstPt, leafNodes)
insertText((plotTree.xOff, plotTree.yOff), firstPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 以下函數執行真正的繪圖操作,plotTree()函數只是樹的一些邏輯和坐標運算
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon = False) #, **axprops)
# 全局變量plotTree.totalW和plotTree.totalD
# 用于存儲樹的寬度和樹的深度
plotTree.totalW = float(getLeafNum(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), ' ')
plt.show()
~~~
首先代碼對整個畫圖區間根據葉子節點數和深度進行平均切分,并且`x`和`y`軸的總長度均為`1`,如同下圖:

**解釋如下**:
1.圖中的方形為非葉子節點的位置,`@`是葉子節點的位置,因此上圖的一個表格的長度應該為:?`1/plotTree.totalW`,但是葉子節點的位置應該為`@`所在位置,則在開始的時候?`plotTree.xOff`?的賦值為:?`-0.5/plotTree.totalW`,即意為開始`x`?軸位置為第一個表格左邊的半個表格距離位置,這樣作的好處是在以后確定`@`位置時候可以直接加整數倍的?`1/plotTree.totalW`。
2.plotTree函數中的一句代碼如下:
~~~
firstPt = (plotTree.xOff + (1.0 + float(leafNum)) / 2.0/ plotTree.totalW, plotTree.yOff)
~~~
其中,變量`plotTree.xOff`即為最近繪制的一個葉子節點的`x`軸坐標,在確定當前節點位置時每次只需確定當前節點有幾個葉子節點,因此其葉子節點所占的總距離就確定了即為:?`float(numLeafs)/plotTree.totalW`,因此當前節點的位置即為其所有葉子節點所占距離的中間即一半為:`float(numLeafs)/2.0/plotTree.totalW`,但是由于開始`plotTree.xOff`賦值并非從`0`開始,而是左移了半個表格,因此還需加上半個表格距離即為:`1/2/plotTree.totalW`,則加起來便為:?`(1.0 + float(numLeafs))/2.0/plotTree.totalW`,因此偏移量確定,則`x`軸的位置變為:`plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW`
3.關于`plotTree()`函數的參數
~~~
plotTree(inTree, (0.5, 1.0), ' ')
~~~
對`plotTree()`函數的第二個參數賦值為`(0.5, 1.0)`,因為開始的根節點并不用劃線,因此父節點和當前節點的位置需要重合,利用2中的確定當前節點的位置為`(0.5, 1.0)`。
**總結**:利用這樣的逐漸增加`x`?軸的坐標,以及逐漸降低`y`軸的坐標能能夠很好的將樹的葉子節點數和深度考慮進去,因此圖的邏輯比例就很好的確定了,即使圖像尺寸改變,我們仍然可以看到按比例繪制的樹形圖。
**二.使用決策樹預測隱形眼鏡類型**
這里實現一個例子,即利用決策樹預測一個患者需要佩戴的隱形眼鏡類型。以下是整個預測的大體步驟:
1. 收集數據:使用書中提供的小型數據集
2. 準備數據:對文本中的數據進行預處理,如解析數據行
3. 分析數據:快速檢查數據,并使用`createPlot()`函數繪制最終的樹形圖
4. 訓練決策樹:使用`createTree()`函數訓練
5. 測試決策樹:編寫簡單的測試函數驗證決策樹的輸出結果&繪圖結果
6. 使用決策樹:這部分可選擇將訓練好的決策樹進行存儲,以便隨時使用
此處新建腳本文件`saveTree.py`,將訓練好的決策樹保存在磁盤中,這里需要使用Python模塊的`pickle`序列化對象。`storeTree()`函數負責把`tree`存放在當前目錄下的`filename(.txt)`文件中,而`getTree(filename)`則是在當前目錄下的`filename(.txt)`文件中讀取決策樹的相關數據。
~~~
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 05 01:56:04 2015
@author: Herbert
"""
import pickle
def storeTree(tree, filename):
fw = open(filename, 'w')
pickle.dump(tree, fw)
fw.close()
def getTree(filename):
fr = open(filename)
return pickle.load(fr)
~~~
以下代碼實現了決策樹預測隱形眼鏡模型的實例,使用的數據集是隱形眼鏡數據集,它包含很多患者的眼部狀況的觀察條件以及醫生推薦的隱形眼鏡類型,其中隱形眼鏡類型包括:硬材質`(hard)`、軟材質`(soft)`和不適合佩戴隱形眼鏡`(no lenses)`?, 數據來源于UCI數據庫。代碼最后調用了之前準備好的`createPlot()`函數繪制樹形圖。
~~~
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 05 14:21:43 2015
@author: Herbert
"""
import tree
import plotTree
import saveTree
fr = open('lenses.txt')
lensesData = [data.strip().split('\t') for data in fr.readlines()]
lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = tree.buildTree(lensesData, lensesLabel)
#print lensesData
print lensesTree
print plotTree.createPlot(lensesTree)
~~~

可以看到,前期實現了決策樹的構建和繪制,使用不同的數據集都可以得到很直觀的結果,從圖中可以看到,沿著決策樹的不同分支,可以得到不同患者需要佩戴的隱形眼鏡的類型。
**三.關于本章使用的決策樹的總結**
回到決策樹的算法層面,以上代碼的實現基于ID3決策樹構造算法,它是一個非常經典的算法,但其實缺點也不少。實際上決策樹的使用中常常會遇到一個問題,即**“過度匹配”**。有時候,過多的分支選擇或匹配選項會給決策帶來負面的效果。為了減少過度匹配的問題,通常算法設計者會在一些實際情況中選擇**“剪枝”**。簡單說來,如果葉子節點只能增加少許信息,則可以刪除該節點。
另外,還有幾種目前很流行的決策樹構造算法:C4.5、C5.0和CART,后期需繼續深入研究。
參考資料:[http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html](http://blog.sina.com.cn/s/blog_7399ad1f01014wec.html)