regTrees.py
from numpy import *
def loadDataSet(fileName):
"""loadDataSet(解析每一行,并转化为float类型)
Desc:该函数读取一个以 tab 键为分隔符的文件,然后将每行的内容保存成一组浮点数
Args:
fileName 文件名
Returns:
dataMat 每一行的数据集array类型
Raises:
"""
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine))
dataMat.append(fltLine)
return dataMat
def binSplitDataSet(dataSet, feature, value):
"""binSplitDataSet(将数据集,按照feature列的value进行 二元切分)
Description:在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。
Args:
dataMat 数据集
feature 待切分的特征列
value 特征列要比较的值
Returns:
mat0 小于等于 value 的数据集在左边
mat1 大于 value 的数据集在右边
Raises:
"""
mat0 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
mat1 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
return mat0, mat1
def regLeaf(dataSet):
return mean(dataSet[:, -1])
def regErr(dataSet):
return var(dataSet[:, -1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
"""chooseBestSplit(用最佳方式切分数据集 和 生成相应的叶节点)
Args:
dataSet 加载的原始数据集
leafType 建立叶子点的函数
errType 误差计算函数(求总方差)
ops [容许误差下降值,切分的最少样本数]。
Returns:
bestIndex feature的index坐标
bestValue 切分的最优值
Raises:
"""
tolS = ops[0]
tolN = ops[1]
if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m, n = shape(dataSet)
S = errType(dataSet)
bestS, bestIndex, bestValue = inf, 0, 0
for featIndex in range(n - 1):
for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
"""createTree(获取回归树)
Description:递归函数:如果构建的是回归树,该模型是一个常数,如果是模型树,其模型师一个线性方程。
Args:
dataSet 加载的原始数据集
leafType 建立叶子点的函数
errType 误差计算函数
ops=(1, 4) [容许误差下降值,切分的最少样本数]
Returns:
retTree 决策树最后的结果
"""
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
'''
*** 最后的返回结果是最后剩下的 val,也就是len小于topN的集合
'''
if feat is None:
return val
retTree = {'spInd': feat, 'spVal': val}
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
def isTree(obj):
"""
Desc:
测试输入变量是否是一棵树,即是否是一个字典
Args:
obj -- 输入变量
Returns:
返回布尔类型的结果。如果 obj 是一个字典,返回true,否则返回 false
"""
return type(obj).__name__ == 'dict'
def getMean(tree):
"""
Desc:
从上往下遍历树直到叶节点为止,如果找到两个叶节点则计算它们的平均值。
对 tree 进行塌陷处理,即返回树平均值。
Args:
tree -- 输入的树
Returns:
返回 tree 节点的平均值
"""
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right']) / 2.0
def prune(tree, testData):
"""
Desc:
从上而下找到叶节点,用测试数据集来判断将这些叶节点合并是否能降低测试误差
Args:
tree -- 待剪枝的树
testData -- 剪枝所需要的测试数据 testData
Returns:
tree -- 剪枝完成的树
"""
if shape(testData)[0] == 0:
return getMean(tree)
if isTree(tree['right']) or isTree(tree['left']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))
treeMean = (tree['left'] + tree['right']) / 2.0
errorMerge = sum(power(testData[:, -1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
def modelLeaf(dataSet):
"""
Desc:
当数据不再需要切分的时候,生成叶节点的模型。
Args:
dataSet -- 输入数据集
Returns:
调用 linearSolve 函数,返回得到的 回归系数ws
"""
ws, X, Y = linearSolve(dataSet)
return ws
def modelErr(dataSet):
"""
Desc:
在给定数据集上计算误差。
Args:
dataSet -- 输入数据集
Returns:
调用 linearSolve 函数,返回 yHat 和 Y 之间的平方误差。
"""
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))
def linearSolve(dataSet):
"""
Desc:
将数据集格式化成目标变量Y和自变量X,执行简单的线性回归,得到ws
Args:
dataSet -- 输入数据
Returns:
ws -- 执行线性回归的回归系数
X -- 格式化自变量X
Y -- 格式化目标变量Y
"""
m, n = shape(dataSet)
X = mat(ones((m, n)))
Y = mat(ones((m, 1)))
X[:, 1: n] = dataSet[:, 0: n - 1]
Y = dataSet[:, -1]
xTx = X.T * X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\ntry increasing the second value of ops')
ws = xTx.I * (X.T * Y)
return ws, X, Y
def regTreeEval(model, inDat):
"""
Desc:
对 回归树 进行预测
Args:
model -- 指定模型,可选值为 回归树模型 或者 模型树模型,这里为回归树
inDat -- 输入的测试数据
Returns:
float(model) -- 将输入的模型数据转换为 浮点数 返回
"""
return float(model)
def modelTreeEval(model, inDat):
"""
Desc:
对 模型树 进行预测
Args:
model -- 输入模型,可选值为 回归树模型 或者 模型树模型,这里为模型树模型,实则为 回归系数
inDat -- 输入的测试数据
Returns:
float(X * model) -- 将测试数据乘以 回归系数 得到一个预测值 ,转化为 浮点数 返回
"""
n = shape(inDat)[1]
X = mat(ones((1, n + 1)))
X[:, 1: n + 1] = inDat
return float(X * model)
def treeForeCast(tree, inData, modelEval=regTreeEval):
"""
Desc:
对特定模型的树进行预测,可以是 回归树 也可以是 模型树
Args:
tree -- 已经训练好的树的模型
inData -- 输入的测试数据,只有一行
modelEval -- 预测的树的模型类型,可选值为 regTreeEval(回归树) 或 modelTreeEval(模型树),默认为回归树
Returns:
返回预测值
"""
if not isTree(tree):
return modelEval(tree, inData)
if inData[0, tree['spInd']] <= tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)
def createForeCast(tree, testData, modelEval=regTreeEval):
"""
Desc:
调用 treeForeCast ,对特定模型的树进行预测,可以是 回归树 也可以是 模型树
Args:
tree -- 已经训练好的树的模型
testData -- 输入的测试数据
modelEval -- 预测的树的模型类型,可选值为 regTreeEval(回归树) 或 modelTreeEval(模型树),默认为回归树
Returns:
返回预测值矩阵
"""
m = len(testData)
yHat = mat(zeros((m, 1)))
for i in range(m):
yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
return yHat
if __name__ == "__main__":
trainMat = mat(loadDataSet('data/bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('data/bikeSpeedVsIq_test.txt'))
myTree1 = createTree(trainMat, ops=(1, 20))
print(myTree1)
yHat1 = createForeCast(myTree1, testMat[:, 0])
print("--------------\n")
print("regTree:", corrcoef(yHat1, testMat[:, 1], rowvar=0)[0, 1])
myTree2 = createTree(trainMat, modelLeaf, modelErr, ops=(1, 20))
yHat2 = createForeCast(myTree2, testMat[:, 0], modelTreeEval)
print(myTree2)
print("modelTree:", corrcoef(yHat2, testMat[:, 1], rowvar=0)[0, 1])
ws, X, Y = linearSolve(trainMat)
print(ws)
m = len(testMat[:, 0])
yHat3 = mat(zeros((m, 1)))
for i in range(shape(testMat)[0]):
yHat3[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0]
print("lr:", corrcoef(yHat3, testMat[:, 1], rowvar=0)[0, 1])
treeExplore.py
import regTrees
from tkinter import *
from numpy import *
import matplotlib
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
matplotlib.use('TkAgg')
def test_widget_text(root):
mylabel = Label(root, text="helloworld")
mylabel.grid()
def reDraw(tolS, tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get():
if tolN < 2:
tolN = 2
myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr, (tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval)
else:
myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:, 0].A, reDraw.rawDat[:, 1].A, s=5)
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0, c='red')
reDraw.canvas.draw()
def getInputs():
try:
tolN = int(tolNentry.get())
except:
tolN = 10
print("enter Integer for tolN")
tolNentry.delete(0, END)
tolNentry.insert(0, '10')
try:
tolS = float(tolSentry.get())
except:
tolS = 1.0
print("enter Float for tolS")
tolSentry.delete(0, END)
tolSentry.insert(0, '1.0')
return tolN, tolS
def drawNewTree():
tolN, tolS = getInputs()
reDraw(tolS, tolN)
def main(root):
Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
global tolNentry
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text="tolS").grid(row=2, column=0)
global tolSentry
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="确定", command=drawNewTree).grid(row=1, column=2, rowspan=3)
global chkBtnVar
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
Button(root, text="退出", fg="black", command=quit).grid(row=1, column=2)
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
reDraw.rawDat = mat(regTrees.loadDataSet('data/sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
if __name__ == "__main__":
root = Tk()
main(root)
root.mainloop()