文章目录
将array格式的图像保存至路径中
灰度数字图像是每个像素只有一个采样颜色(亮度值)的图像。这类图像通常显示为从最暗的黑色到最亮的白色之间的各级灰度。下面的函数先对像素做反相处理(用 255 减去像素值),再将结果保存为 PNG 文件。
def save_image(im, i, output_path=None):
    """Save a grayscale pixel array as an inverted PNG file named ``<i>.png``.

    The pixels are inverted (``255 - im``) before saving, so high values are
    written as dark pixels and low values as light ones.

    Args:
        im: array-like of pixel intensities (typically 2-D); values are
            expected in [0, 255] and are clipped to that range.
        i: integer used as the file name stem.
        output_path: directory to write into; defaults to ``./HandWritten``.
            It is created (including any parents) if it does not exist.
    """
    if output_path is None:
        output_path = os.path.join('.', 'HandWritten')
    # Invert, then clip so out-of-range values cannot wrap around when cast to uint8.
    inverted = np.clip(255 - np.asarray(im), 0, 255).astype(np.uint8)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and handles nested paths.
    os.makedirs(output_path, exist_ok=True)
    # Image.fromarray is the inverse of np.array(image); os.path.join keeps the
    # path portable (a literal backslash is only a separator on Windows).
    Image.fromarray(inverted).save(os.path.join(output_path, '%d.png' % i))
案例
#!/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from PIL import Image
import warnings
from sklearn.metrics import accuracy_score
import pandas as pd
import os
import csv
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from time import time
from pprint import pprint
def save_image(im, i, output_path=None):
    """Save a grayscale pixel array as an inverted PNG file named ``<i>.png``.

    The pixels are inverted (``255 - im``) before saving, so high values are
    written as dark pixels and low values as light ones.

    Args:
        im: array-like of pixel intensities (typically 2-D); values are
            expected in [0, 255] and are clipped to that range.
        i: integer used as the file name stem.
        output_path: directory to write into; defaults to ``./HandWritten``.
            It is created (including any parents) if it does not exist.
    """
    if output_path is None:
        output_path = os.path.join('.', 'HandWritten')
    # Invert, then clip so out-of-range values cannot wrap around when cast to uint8.
    inverted = np.clip(255 - np.asarray(im), 0, 255).astype(np.uint8)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and handles nested paths.
    os.makedirs(output_path, exist_ok=True)
    # os.path.join keeps the path portable (a literal backslash is only a
    # separator on Windows).
    Image.fromarray(inverted).save(os.path.join(output_path, '%d.png' % i))
def save_result(model, data=None, path='Prediction.csv'):
    """Write a model's predictions to a CSV file with columns ImageId, Label.

    One row is written per prediction, with ImageId numbered from 1.

    Args:
        model: fitted estimator exposing ``predict``.
        data: feature matrix to predict on; defaults to the module-level
            ``data_test`` loaded in the ``__main__`` section.
        path: output CSV path.
    """
    if data is None:
        data = data_test
    predictions = model.predict(data)
    # On Python 3, csv.writer needs a text-mode file opened with newline='';
    # a binary ('wb') file makes writerow raise TypeError.
    with open(path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['ImageId', 'Label'])
        # 1-based IDs (same numbering as np.arange(1, len(predictions) + 1)).
        writer.writerows(enumerate(predictions, start=1))
def _minutes_seconds(seconds):
    """Split a duration in seconds into (whole minutes, remaining seconds)."""
    minutes, rest = divmod(seconds, 60)
    return int(minutes), rest


if __name__ == "__main__":
    # Suppress warnings to keep the console output readable.
    warnings.filterwarnings(action='ignore')
    # Classifier to train: 'SVM' (grid-searched RBF SVC) or 'RF' (random forest).
    classifier_type = 'RF'

    print('载入训练数据...')
    t = time()
    # dtype=int: np.int was only an alias of the builtin int and was removed in
    # NumPy 1.24. os.path.join keeps the path portable beyond Windows.
    data = pd.read_csv(os.path.join('.', 'MNIST.train.csv'), header=0, dtype=int)
    print('载入完成,耗时%f秒' % (time() - t))
    # Labels come from the 'label' column; pixels are every column after the
    # first (assumes 'label' is column 0 — confirm against the CSV header).
    y = data['label'].values
    x = data.values[:, 1:]
    print('图片个数:%d,图片像素数目:%d' % x.shape)

    print('载入测试数据...')
    t = time()
    # data_test must remain a module-level global: save_result() reads it.
    data_test = pd.read_csv(os.path.join('.', 'MNIST.test.csv'), header=0, dtype=int)
    data_test = data_test.values
    images_test_result = data_test.reshape(-1, 28, 28)
    print('载入完成,耗时%f秒' % (time() - t))

    np.random.seed(0)
    # Hold out 20% of the labelled data as a validation set.
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.8, random_state=1)
    images = x.reshape(-1, 28, 28)
    images_test = x_test.reshape(-1, 28, 28)
    print(x.shape, x_test.shape)

    # Preview: first 16 training images (top half) and first 16 unlabelled
    # test images (bottom half); the test previews are also saved as PNGs.
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(15, 9), facecolor='w')
    for index, image in enumerate(images[:16]):
        plt.subplot(4, 8, index + 1)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(u'训练图片: %i' % y[index])
    for index, image in enumerate(images_test_result[:16]):
        plt.subplot(4, 8, index + 17)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        save_image(image.copy(), index)
        plt.title(u'测试图片')
    plt.tight_layout()
    plt.show()

    if classifier_type == 'SVM':
        params = {'C': np.logspace(1, 4, 4, base=10),
                  'gamma': np.logspace(-10, -2, 9, base=10)}
        clf = svm.SVC(kernel='rbf')
        model = GridSearchCV(clf, param_grid=params, cv=3)
        # Faster alternative without grid search:
        # model = svm.SVC(C=1000, kernel='rbf', gamma=1e-10)
        print('SVM开始训练...')
        t = time()
        model.fit(x, y)
        minutes, seconds = _minutes_seconds(time() - t)
        print('SVM训练结束,耗时%d分钟%.3f秒' % (minutes, seconds))
        print('最优分类器:', model.best_estimator_)
        print('最优参数:\t', model.best_params_)
        print('model.cv_results_ =', model.cv_results_)

        t = time()
        y_hat = model.predict(x)
        minutes, seconds = _minutes_seconds(time() - t)
        print('SVM训练集准确率:%.3f%%,耗时%d分钟%.3f秒' % (accuracy_score(y, y_hat)*100, minutes, seconds))

        t = time()
        y_test_hat = model.predict(x_test)
        minutes, seconds = _minutes_seconds(time() - t)
        print('SVM测试集准确率:%.3f%%,耗时%d分钟%.3f秒' % (accuracy_score(y_test, y_test_hat)*100, minutes, seconds))
        save_result(model)
    elif classifier_type == 'RF':
        # min_impurity_split was removed in scikit-learn 1.0 (passing it raises
        # TypeError). With a 1e-10 threshold it behaved essentially like the
        # current default, where only pure nodes become leaves.
        rfc = RandomForestClassifier(100, criterion='gini', min_samples_split=2,
                                     bootstrap=True, oob_score=True)
        print('随机森林开始训练...')
        t = time()
        rfc.fit(x, y)
        minutes, seconds = _minutes_seconds(time() - t)
        print('随机森林训练结束,耗时%d分钟%.3f秒' % (minutes, seconds))
        print('OOB准确率:%.3f%%' % (rfc.oob_score_*100))

        t = time()
        y_hat = rfc.predict(x)
        t = time() - t
        print('随机森林训练集准确率:%.3f%%,预测耗时:%d秒' % (accuracy_score(y, y_hat)*100, t))

        t = time()
        y_test_hat = rfc.predict(x_test)
        t = time() - t
        print('随机森林测试集准确率:%.3f%%,预测耗时:%d秒' % (accuracy_score(y_test, y_test_hat)*100, t))
    else:
        # Fail fast instead of hitting a NameError on y_test_hat below.
        raise ValueError('unknown classifier_type: %r' % classifier_type)

    # Show up to 12 misclassified validation images for whichever classifier ran.
    err = (y_test != y_test_hat)
    err_images = images_test[err]
    err_y_hat = y_test_hat[err]
    err_y = y_test[err]
    print(err_y_hat)
    print(err_y)
    plt.figure(figsize=(10, 8), facecolor='w')
    for index, image in enumerate(err_images[:12]):
        plt.subplot(3, 4, index + 1)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(u'错分为:%i,真实值:%i' % (err_y_hat[index], err_y[index]))
    plt.suptitle(u'数字图片手写体识别:分类器%s' % classifier_type, fontsize=18)
    plt.tight_layout(rect=(0, 0, 1, 0.95))
    plt.show()
载入训练数据...
载入完成,耗时1.746520秒
图片个数:42000,图片像素数目:784
载入测试数据...
载入完成,耗时1.179636秒
(33600, 784) (8400, 784)
随机森林开始训练...
随机森林训练结束,耗时0分钟18.557秒
OOB准确率:95.821%
随机森林训练集准确率:100.000%,预测耗时:1秒
随机森林测试集准确率:96.512%,预测耗时:0秒
[2 6 0 7 8 7 0 7 5 0 2 2 9 3 9 8 2 8 7 0 4 9 3 9 2 8 9 7 4 4 1 5 8 7 5 3 2
4 0 1 8 7 3 2 6 5 8 4 9 2 7 8 5 2 4 9 9 6 8 2 5 3 9 2 5 1 4 6 2 8 8 0 8 2
2 1 9 8 4 9 3 3 2 8 7 6 8 9 7 3 5 3 1 2 3 9 9 3 9 8 4 4 7 8 3 3 7 3 4 4 0
9 9 1 7 4 9 5 2 8 8 3 5 5 8 5 1 3 6 2 7 7 3 6 4 3 4 5 0 7 4 9 5 1 4 3 3 5
8 7 9 2 0 8 3 3 2 6 5 9 9 9 8 1 3 7 1 5 5 3 4 1 2 9 5 2 8 3 1 4 4 3 2 8 3
4 4 3 9 2 5 7 1 7 6 8 0 5 9 5 6 5 0 8 8 7 0 6 4 8 7 9 4 3 2 4 4 2 8 6 3 1
9 2 9 6 2 8 9 5 8 4 4 0 0 2 2 6 9 7 9 4 5 0 6 2 3 6 5 9 9 9 2 5 2 9 9 8 8
5 4 2 3 9 3 7 9 0 5 0 9 5 3 2 4 3 9 8 8 3 9 5 2 7 9 2 8 5 8 5 4 5 8]
[4 5 6 4 2 2 5 2 9 7 7 3 7 9 4 3 7 5 3 5 6 4 5 3 7 2 4 4 6 8 4 8 9 8 3 5 3
0 5 8 2 3 8 9 0 6 9 8 5 3 3 2 9 1 9 8 5 5 3 7 3 5 4 3 6 8 2 4 3 1 3 3 5 3
4 8 8 2 7 4 9 5 4 9 3 2 3 8 2 8 3 5 8 3 9 7 3 2 7 9 2 8 9 4 9 1 0 9 6 9 9
4 4 7 9 9 4 8 7 3 5 5 9 3 6 9 8 5 4 7 5 3 5 5 9 9 9 3 2 5 8 4 6 8 9 5 9 0
3 3 2 7 8 4 9 5 1 4 6 3 4 8 9 5 1 9 7 8 3 1 9 8 3 7 3 7 3 8 3 9 9 1 3 3 2
9 7 5 7 0 3 2 9 9 8 6 6 8 4 3 4 3 9 2 0 9 5 5 7 9 2 3 9 9 1 9 7 1 9 8 9 8
4 7 7 0 1 7 7 3 9 3 9 4 9 3 3 2 4 9 4 8 1 5 5 3 8 4 1 7 7 7 3 8 3 7 5 3 0
3 7 7 2 7 8 4 4 2 9 3 4 8 1 8 9 5 7 5 3 1 7 8 3 9 8 3 2 6 5 6 8 8 3]
载入训练数据...
载入完成,耗时1.661703秒
图片个数:42000,图片像素数目:784
载入测试数据...
载入完成,耗时1.106137秒
(33600, 784) (8400, 784)
SVM开始训练...
SVM训练结束,耗时2分钟21.074秒
SVM训练集准确率:94.676%,耗时4分钟13.307秒
SVM测试集准确率:93.810%,耗时1分钟3.327秒