Data columns (total 18 columns):
CUST_ID                             8950 non-null object
BALANCE                             8950 non-null float64
BALANCE_FREQUENCY                   8950 non-null float64
PURCHASES                           8950 non-null float64
ONEOFF_PURCHASES                    8950 non-null float64
INSTALLMENTS_PURCHASES              8950 non-null float64
CASH_ADVANCE                        8950 non-null float64
PURCHASES_FREQUENCY                 8950 non-null float64
ONEOFF_PURCHASES_FREQUENCY          8950 non-null float64
PURCHASES_INSTALLMENTS_FREQUENCY    8950 non-null float64
CASH_ADVANCE_FREQUENCY              8950 non-null float64
CASH_ADVANCE_TRX                    8950 non-null int64
PURCHASES_TRX                       8950 non-null int64
CREDIT_LIMIT                        8949 non-null float64
PAYMENTS                            8950 non-null float64
MINIMUM_PAYMENTS                    8637 non-null float64
PRC_FULL_PAYMENT                    8950 non-null float64
TENURE                              8950 non-null int64
一、加载库
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.cluster import KMeans

try:
    # Correct spelling, available from scikit-learn 0.20 onwards.
    from sklearn.metrics import calinski_harabasz_score
except ImportError:
    # Older releases only expose the misspelled name (removed in 0.23).
    from sklearn.metrics import calinski_harabaz_score as calinski_harabasz_score
# Legacy alias so call sites that still use the old spelling keep working.
calinski_harabaz_score = calinski_harabasz_score
# Use the SimHei font so Chinese text renders in figures, and disable the
# Unicode minus glyph so the minus sign (-) still displays with that font.
mpl.rcParams.update({
    "font.family": "SimHei",
    "axes.unicode_minus": False,
})
二、数据预处理
# Load the dataset; the first row of the CSV holds the column names.
data = pd.read_csv(r"credit_card.csv", header=0)

# Optional checks while exploring the data (left disabled):
#   data.info()                          - dtypes and non-null counts
#   data.describe()                      - summary statistics, to spot outliers
#   data.duplicated().any()              - whether duplicate rows exist
#   data.drop_duplicates(inplace=True)   - remove duplicates if there are any

# Remove three columns in place before clustering.
# NOTE(review): presumably CUST_ID is dropped because it is a non-numeric
# identifier, and MINIMUM_PAYMENTS / CREDIT_LIMIT because the data.info()
# listing shows missing values in them - confirm.
data.drop(
    labels=["CUST_ID", "MINIMUM_PAYMENTS", "CREDIT_LIMIT"],
    axis="columns",
    inplace=True,
)

# X refers to the same DataFrame object as `data` (no copy is made).
X = data

# Preview the raw points using the columns at positions 2 and 5.
# NOTE(review): per the data.info() listing these are presumably PURCHASES
# and CASH_ADVANCE once CUST_ID is removed - verify against the CSV.
features = X.values
plt.scatter(features[:, 2], features[:, 5], marker="o")
plt.show()
三、调用KMeans方法
# Fit KMeans for every candidate cluster count from 2 to 7, plot the resulting
# partition on the columns at positions 2 and 5, and print the
# Calinski-Harabasz index for each k.
for k in range(2, 8):
    # Instantiate the KMeans model; random_state is fixed so runs are repeatable.
    km = KMeans(n_clusters=k, random_state=0)
    # Fit the model and obtain the cluster label of every sample.
    result = km.fit_predict(X)
    # Open exactly one figure per k. The previous plt.figure() + plt.subplots()
    # pair created two figures and left the 12x8 one empty.
    plt.figure(figsize=(12, 8))
    # Label the figure so the plots for different k can be told apart.
    plt.title("n_clusters = {}".format(k))
    plt.scatter(X.values[:, 2], X.values[:, 5], c=result)
    # Mark the cluster centres on the same two columns.
    plt.scatter(km.cluster_centers_[:, 2], km.cluster_centers_[:, 5], marker="+", s=100)
    # Use the correctly spelled metric; the misspelled calinski_harabaz_score
    # was removed from scikit-learn in 0.23.
    score = calinski_harabasz_score(X, result)
    print("For n_clusters =", k,
          "The calinski_harabaz_score is :", score)
四、结果可视化
# Refit KMeans with 5 clusters (fixed seed) and label every sample.
km = KMeans(n_clusters=5, random_state=0)
result = km.fit_predict(X)

# Plot the samples coloured by cluster on the columns at positions 2 and 5,
# then overlay the cluster centres as "+" markers.
features = X.values
centers = km.cluster_centers_
plt.figure(figsize=(12, 8))
plt.scatter(features[:, 2], features[:, 5], c=result)
plt.scatter(centers[:, 2], centers[:, 5], marker="+", s=100)
plt.show()