西瓜书上朴素贝叶斯的实现,完全按照书上的步骤

注:西瓜书上的数据有错误如P152的5/8=0.375,所以代码的计算是正确的。如果读者想要“拉普拉斯修正“的源码请访问https://download.csdn.net/download/song91425/10385345 。 所谓的拉普拉斯就是避免出现概率为0的情况。

import numpy as np


def load_data(filepath):
    '''
    :arg filepath  filepath是数据的路径
    :fun 加载数据:1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,0.697,0.46,是
    :return 加载后的数据
    '''

    file_object = open(filepath, encoding='UTF-8')
    train_data = []
    file_object.readline()
    while 1:
        data = file_object.readline()
        if not data:
            break
        else:
            train_data.append(data)
    file_object.close()
    test = []
    for s in train_data:
        test.append(s.replace('\n', '').split(','))   #去掉\n和把数据按照’,‘分割再存
    return test


def count_labels(data):
    '''

    :param data:数据集
    :return: 返回好瓜和坏瓜的数目
    '''
    yes = 0
    no = 0
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            yes += 1
        else:
            no += 1
    return yes, no


def handle_one_data(data, attr, location, yes, no):
    '''
    :param data: 数据集
    :param attr: 要传入的属性
    :param location: 传入属性的位置
    :param yes: 好瓜数量
    :param no: 坏瓜数量
    :return: 返回该属性在好瓜或者是坏瓜的前提下的概率
    '''
    attr_y, attr_n = 0, 0
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            if data[s][location] == attr:
                attr_y += 1
        else:
            if data[s][location] == attr:
                attr_n += 1
    return attr_y / yes, attr_n / no


def handle_data(data):
    '''

    :param data: 数据集
    :return: 对密度和含糖率的均值和标准差
    '''
    midu_y = []
    tiandu_y = []
    midu_n = []
    tiandu_n = []
    for s in range(data.__len__()):
        if data[s][-1] == '是':
            midu_y.append(np.float(data[s][-3]))
            tiandu_y.append(np.float(data[s][-2]))
        else:
            midu_n.append(np.float(data[s][-3]))
            tiandu_n.append(np.float(data[s][-2]))
    m_midu_y = np.mean(midu_y)
    m_midu_n = np.mean(midu_n)
    t_tiandu_y = np.mean(tiandu_y)
    t_tiandu_n = np.mean(tiandu_n)
    std_midu_y = np.std(midu_y)
    std_midu_n = np.std(midu_n)
    std_tiandu_y = np.std(tiandu_y)
    std_tiandu_n = np.std(tiandu_n)

    return m_midu_y, m_midu_n, t_tiandu_y, t_tiandu_n, std_midu_y, std_midu_n, std_tiandu_y, std_tiandu_n


def show_result(p_yes, p_no):
    '''

    :param p_yes: 在好瓜的前提下,测试数据各个属性的概率
    :param p_no: 在是坏瓜的前提下,测试数据的各个属性的概率
    :return: 是好瓜或者是坏瓜
    '''
    p1 = 1.0
    p2 = 1.0
    for s in range(p_yes.__len__()):
        p1 *= np.float(p_yes[s])
        p2 *= np.float(p_no[s])
    if p1 > p2:
        print("好瓜", p1, p2)
    else:
        print("坏瓜", p1, p2)


def count_attr_dis(data):
    '''

    :param data: 数据集
    :return: 各个属性取值的个数
    '''
    count = [] # 记录各个属性的取值有多少个不同
    for i in range(data[0].__len__()):
        if i == 0 or i == 7 or i == 8: # 去掉编号,密度,甜度这个属性
           continue
        d = []
        for s in range(data.__len__()):
            if not d.__contains__(data[s][i]): # 如果读到的属性不包含在d里就加入到d中
                d.append(data[s][i])
        count.append(d.__len__())  # 统计属性取值不同的个数
    return count


if __name__ == '__main__':
    filepath = 'D:\\pycharm\\bayes.txt'
    data = load_data(filepath)
    m_midu_y, m_midu_n, t_tiandu_y, t_tiandu_n, std_midu_y, std_midu_n, std_tiandu_y, std_tiandu_n = handle_data(data)
    yes, no = count_labels(data)
    p_yes = [yes / (yes + no)]
    p_no = [no / (yes + no)]
    test_data = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]
    for s in range(6):
        s_yes, s_no = handle_one_data(data, test_data[s], s+1, yes, no)
        p_yes.append(s_yes)
        p_no.append(s_no)

    #求西瓜书公式(7.18)
    p_yes.append(1/(np.sqrt(2*np.pi) * std_midu_y) * np.exp((-1) * ((test_data[6] - m_midu_y)**2)/std_midu_y**2))
    p_no.append(1/(np.sqrt(2 * np.pi) * std_midu_n) * np.exp((-1) * ((test_data[6] - m_midu_n) ** 2) / std_midu_n ** 2))

    p_yes.append(1/(np.sqrt(2 * np.pi) * std_tiandu_y) * np.exp((-1) * ((test_data[7] - t_tiandu_y) ** 2) / std_tiandu_y ** 2))
    p_no.append(1/(np.sqrt(2 * np.pi) * std_tiandu_n) * np.exp((-1) * ((test_data[7] - t_tiandu_n) ** 2) / std_tiandu_n ** 2))

    print(p_yes)
    print(p_no)
    show_result(p_yes, p_no)

    # 防止某个属性的取值个数为0的概率出现,采用拉皮拉斯修正(各个属性不同取值已经完成如函数count_attr_dis)

    print(count_attr_dis(data), '不同属性取值')

猜你喜欢

转载自blog.csdn.net/song91425/article/details/80157884