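# 遍历一个文件夹下的所有图片,进行单张预测,并复制到相应的文件夹
# Walk every image under a folder, classify each image individually, and save a
# copy of it into the output folder of its predicted class.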
import caffe
#import lmdb
import numpy as np
import cv2
from caffe.proto import caffe_pb2
import os
import sys
# Switch Caffe to GPU mode; done once, before the network is constructed below.
caffe.set_mode_gpu()
def dirlist(path, allfile=None):
    """Collect the paths of all non-directory entries under *path*, recursively.

    Entries are visited depth-first in ``os.listdir`` order, so the contents of
    a subdirectory appear at the position where that subdirectory was listed.
    The walk uses an explicit stack instead of recursion, so very deep trees
    cannot hit Python's recursion limit.  Because ``os.path.isdir`` follows
    symbolic links, links to directories are descended into.

    Args:
        path: Directory to scan.
        allfile: Optional list to append the found paths to.  When omitted
            (or ``None``), a new list is created.

    Returns:
        The list the paths were appended to (``allfile`` itself when given).
    """
    if allfile is None:
        allfile = []
    # Each stack entry is (directory, iterator over its not-yet-visited
    # entries); the top entry is the directory currently being scanned.
    stack = [(path, iter(os.listdir(path)))]
    while stack:
        directory, entries = stack[-1]
        filename = next(entries, None)
        if filename is None:  # this directory is exhausted
            stack.pop()
            continue
        filepath = os.path.join(directory, filename)
        if os.path.isdir(filepath):
            # Descend immediately to keep the same depth-first order as a
            # recursive walk.
            stack.append((filepath, iter(os.listdir(filepath))))
        else:
            allfile.append(filepath)
    return allfile
# sys.setrecursionlimit(1000000)
def is_bgr_img(img):
    """Return True if *img* looks like a 3-dimensional image array.

    Intended as a cheap validity check on the result of ``cv2.imread``, which
    yields ``None`` for files it cannot read.  Only the number of dimensions
    is tested; the size of the last (channel) axis is not checked.

    Args:
        img: Object to test, typically the return value of ``cv2.imread``.

    Returns:
        ``True`` when ``img.shape`` unpacks into exactly three values, else
        ``False`` -- including for ``None``, 2-D (grayscale) or 4-D arrays,
        and objects without a usable ``shape``.
    """
    try:
        _, _, _ = img.shape
    except AttributeError:
        # None (failed read) or an object with no ``shape`` attribute.
        return False
    except (TypeError, ValueError):
        # ``shape`` exists but is not a 3-element sequence; previously this
        # escaped as an uncaught ValueError for 2-D grayscale arrays.
        return False
    return True
# ---- Load the trained Caffe model ----
root = 'D:/stomach_raw_data/deepid/'  # root directory of the model files
deploy = root + 'deploy_all.prototxt'  # deploy (inference) network definition
caffe_model = root + 'id_128_net_iter_1695000.caffemodel'  # trained weights
labels_filename = root + 'labels.txt'  # class-name file: maps numeric labels back to class names
# Build the network from the deploy definition and the trained weights, in TEST phase.
net = caffe.Net(deploy, caffe_model, caffe.TEST)
# Input transformer.  It is registered under the key 'data' but takes its target
# shape (count, channels, height, width) from the network's 'data_1' input blob,
# so the shape follows the deploy file rather than being fixed in this script.
transformer = caffe.io.Transformer({'data': net.blobs['data_1'].data.shape})
# Reorder the input axes with (2, 0, 1): the last axis (channels) moves to the
# front, i.e. height x width x channels -> channels x height x width.
# This only permutes axes; it does NOT change the colour order.
transformer.set_transpose('data', (2, 0, 1))
# No mean subtraction, raw scaling or channel swap is configured: the model was
# trained without mean subtraction, so pixel values are not rescaled or
# mean-centred here.  Alternatives kept for reference:
# transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
# transformer.set_raw_scale('data', 255)  # scale into [0, 255]
# transformer.set_channel_swap('data', (2, 1, 0))  # reverse channel order (RGB <-> BGR)
# Class names, read as tab-delimited text.
# NOTE(review): `labels` is not referenced anywhere else in this script.
labels = np.loadtxt(labels_filename, str, delimiter='\t')
# ---- Classify every image and file it under its predicted class ----
# Output sub-folder for each class index predicted by the network.
dirs = ['0_CA', '1_FV', '2_GB', '3_GA', '4_SV', '5_PY', '6_OT', '7_IV']
PROB_THRESHOLD = 0.70  # top-1 probability must exceed this, else the image goes to 'unkown'
input_dir = 'D:\\2D'  # tree of images to classify (scanned recursively)
imgnames = dirlist(input_dir, [])
path = 'D:/sto_img_1695000/'  # output root; one sub-folder per class
if imgnames:
    temp = imgnames[0]
    # Debug output: the part before the first '_' of the first file's parent
    # folder name, followed by that file's full path.
    print(os.path.basename(os.path.dirname(temp)).split('_')[0])
    print(temp)
else:
    print('no files found under %s' % input_dir)
infer_ms = 0.0  # total time spent in net.forward(), in milliseconds
n_done = 0  # number of images classified so far
for imgname in imgnames:
    image = cv2.imread(imgname)
    if image is None:
        # cv2.imread could not read this file.
        # NOTE(review): the source file is then DELETED.  cv2.imread may also
        # fail on readable images (reported for some Windows paths containing
        # non-ASCII characters) -- confirm this removal is really intended.
        print(imgname)
        os.remove(imgname)
        continue
    net.blobs['data_1'].data[...] = transformer.preprocess('data', image)
    t1 = cv2.getTickCount()
    net.forward()
    infer_ms += (cv2.getTickCount() - t1) * 1000 / cv2.getTickFrequency()
    n_done += 1
    prob = net.blobs['softmax'].data[0].flatten()
    order = prob.argsort()[-1]  # index of the most probable class
    prob_max = prob[order]
    print('max = %f,class = %d,all = %d\n' % (prob_max, order, n_done))
    # Confident predictions go to their class folder, everything else to
    # 'unkown' (folder name kept as-is so the output layout is unchanged).
    imgpath = path + (dirs[order] if prob_max > PROB_THRESHOLD else 'unkown')
    if not os.path.exists(imgpath):
        os.makedirs(imgpath)  # also creates the output root if it is missing
    # NOTE(review): only the base name is kept, so equally named images from
    # different input sub-folders overwrite each other in the output folder.
    outfile = imgpath + '/' + os.path.basename(imgname)
    if not cv2.imwrite(outfile, image):
        print('failed to write %s' % outfile)
    cv2.imshow('cv2', image)
    k = cv2.waitKey(1)
    if k == 27:  # Esc: stop processing
        break
    if k == 32:  # Space: pause until any key is pressed
        cv2.waitKey()
cv2.destroyAllWindows()
if n_done:
    print('average forward time: %.2f ms over %d images' % (infer_ms / n_done, n_done))