代码全部在jupyter中实现
一、 导包和导入相关数据
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.linear_model import LinearRegression,Ridge,Lasso
from sklearn.neighbors import KNeighborsRegressor
import sklearn.datasets as datasets
#人脸数据
faces = datasets.fetch_olivetti_faces()
查看数据:
faces
data = faces['images']
data.shape
随意查看一张人脸
index = np.random.randint(400,size =1)[0]
plt.imshow(data[index],cmap=plt.cm.gray)
二、进行人脸的图片分割
#上半张人脸
X = data[:,:32].reshape(400,-1)
#下半张人脸
y = data[:,32:].reshape(400,-1)
#导包,进行测试数据和训练数据的分割
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 10)
#10个位测试数据,test_size可以为int类型
进行简单的测试:
index = np.random.randint(390,size = 1)[0]
face_up = X_train[index].reshape(32,64)
face_down = y_train[index].reshape(32,64)
ax = plt.subplot(1,3,1)
ax.imshow(face_up,cmap = plt.cm.gray)
ax = plt.subplot(1,3,2)
ax.imshow(face_down,cmap = plt.cm.gray)
ax = plt.subplot(1,3,3)
ax.imshow(np.concatenate([face_up,face_down],axis = 0),cmap = 'gray')
三、训练和测试
四种算法 KNN,线性回归,岭回归,lasso
estimator = {}
estimator['KNN'] = KNeighborsRegressor(n_neighbors=5)
estimator['Lr'] = LinearRegression()
estimator['Ridge'] = Ridge(alpha=1)
estimator['Lasso'] = Lasso(alpha=1)
进行训练:
predict_ = {}
for key,model in estimator.items():
model.fit(X_train,y_train)
#预测:根据上半张脸预测下半张连
y_ = model.predict(X_test)
predict_[key] = y_
可视化,进行比较
#可视化
#10行,6列
plt.figure(figsize = (6*2,10*2))
for i in range(10):
#第一列
ax = plt.subplot(10,6,1 + i * 6)
face_up = X_test[i].reshape(32,64)
face_down = y_test[i].reshape(32,64)
ax.imshow(np.concatenate([face_up,face_down],axis = 0),cmap = 'gray')
ax.axis('off') #去除刻度
if i == 0:
ax.set_title('True') #添加标题
#第二列
ax = plt.subplot(10,6,2 + i * 6)
ax.imshow(face_up,cmap = 'gray')
if i == 0:
ax.set_title('Face_up')
#第三、四、五、六列
#预测人脸predict_
for j ,key in enumerate(predict_):
ax = plt.subplot(10,6,3 + j + i*6 )
y_ = predict_[key]
face_down_ = y_[i].reshape(32,64)
ax.imshow(np.concatenate([face_up,face_down_],axis = 0),cmap = 'gray')
ax.axis('off')
if i == 0:
ax.set_title(key)
如图为测试结果图:
希望提出建议进行改进。