十四、聚类实战——图片压缩

对同一像素点值的像素点归为一类,通过平均值进行取代,从而将图像进行压缩并且保证图像尽可能不失真,关键信息仍保留。

from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def restore_image(cb, cluster, shape):
    row, col, dummy = shape
    image = np.empty((row, col, 3))
    index = 0
    for r in range(row):
        for c in range(col):
            image[r, c] = cb[cluster[index]]
            index += 1
    return image


def show_scatter(a):
    N = 10
    print('原始数据:\n', a)
    density, edges = np.histogramdd(a, bins=[N,N,N], range=[(0,1), (0,1), (0,1)])
    density /= density.max()
    x = y = z = np.arange(N)
    d = np.meshgrid(x, y, z)

    fig = plt.figure(1, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(d[1], d[0], d[2], c='r', s=100*density, marker='o', depthshade=True)
    ax.set_xlabel(u'红色分量')
    ax.set_ylabel(u'绿色分量')
    ax.set_zlabel(u'蓝色分量')
    plt.title(u'图像颜色三维频数分布', fontsize=20)

    plt.figure(2, facecolor='w')
    den = density[density > 0]
    den = np.sort(den)[::-1]
    t = np.arange(len(den))
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'图像颜色频数分布', fontsize=18)
    plt.grid(True)

    plt.show()


if __name__ == '__main__':
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False

    num_vq = 256 #256个像素,最后想降维到256个维度
    im = Image.open('Lena.png')     # flower2.png(200)/lena.png(50)  读取图片
    image = np.array(im).astype(np.float) / 255 # 将图像数据转化为array类型,方便后续的操作
    image = image[:, :, :3]#所有行、列、前三个维度,因为png有四个属性,RGBα,alpha为透明度,不需要,只用前三个维度信息即可
    image_v = image.reshape((-1, 3))#拉伸像素点,行不关心,列为3列===转换为二维数据,每一列均为一个维度的全部数据,RGB变成了3列,每行为一个像素点
    model = KMeans(num_vq)#通过Kmeans对每一行进行处理,也就是每个像素点进行处理,对每个像素点进行分类,像素点相似度的归类;创建聚类对象
    #每类都有一个中心像素点,将该中心像素点代替这一类,这里手动传入的是分成256个类别。不管像素点位置,只考虑相似与否是否为一类
    show_scatter(image_v)#画图

    N = image_v.shape[0]    # 图像像素总数
    # 选择足够多的样本(如1000个),计算聚类中心
    idx = np.random.randint(0, N, size=1000)#从图像中随机选取1000个像素点
    image_sample = image_v[idx]
    model.fit(image_sample)#将这1000个像素点去训练模型,聚类结果,从1000个像素点中找到最重要的256个像素点作为中心点
    c = model.predict(image_v)  # 将图像全部的像素点进行预测,看看图像中的所有像素点离这256个簇哪一个最近,把图像的所有像素点进行分类
    print('聚类结果:\n', c)
    print('聚类中心:\n', model.cluster_centers_)

    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(121)
    plt.axis('off')
    plt.title(u'原始图片', fontsize=18)
    plt.imshow(image)
    #plt.savefig('1.png')

    plt.subplot(122)
    vq_image = restore_image(model.cluster_centers_, c, image.shape)#聚类中心点、聚类结果、模型图像的形状   作为参数进行恢复图像
    plt.axis('off')
    plt.title(u'矢量量化后图片:%d色' % num_vq, fontsize=20)
    plt.imshow(vq_image)
    #plt.savefig('2.png')

    plt.tight_layout()
    plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_41264055/article/details/124895343