下面是python代码:
# coding=UTF-8
"""
使用训练好的caffe模型预测手写体程序
"""
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
caffe_root="/home/pcb/caffe/" #设置caffee的目录
sys.path.insert(0,caffe_root+"python")
import caffe
#指定LetNet的网络定义模型
Model_file="/home/pcb/caffe/examples/mnist/lenet.prototxt"
#加载训练好的Model模型
Pretrained="/home/pcb/caffe/examples/mnist/lenet_iter_10000.caffemodel"
#测试图片路径
Image_file="/home/pcb/caffe/data/mnist/train_13.bmp"
#caffe接口载入文件
input_image=caffe.io.load_image(Image_file,color=False)
#载入LeNet分类器
net=caffe.Classifier(Model_file,Pretrained)
prediction=net.predict([input_image],oversample=False)
caffe.set_mode_cpu() #设置为CPU模式
print "predicted calss:",prediction[0].argmax()
最终预测结果为:
选择了一个6的手写体图像,然后最后分类结果也是6,这样就正确分类了!