使用 TensorFlow 调用目标检测模型并显示检测结果
import numpy as np
import tensorflow as tf
import cv2 as cv
# Run a frozen TensorFlow object-detection graph on a single image and show
# the image with a rectangle drawn around every sufficiently confident
# detection.

# Everything a user is expected to adapt lives here.
MODEL_PATH = 'exported_model/frozen_inference_graph.pb'
IMAGE_PATH = 'mouth_dataset/test/1_107.jpg'
SCORE_THRESHOLD = 0.3    # only detections scoring above this are drawn
INPUT_SIZE = (300, 300)  # (width, height) passed to cv.resize before inference
BOX_COLOR = (255, 255, 0)

# Load the serialized frozen graph from disk.
# tf.gfile.GFile is used instead of the deprecated tf.gfile.FastGFile alias.
with tf.gfile.GFile(MODEL_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# cv.imread reports failure by returning None rather than raising, so check
# explicitly instead of crashing later inside cv.resize.
img = cv.imread(IMAGE_PATH)
if img is None:
    raise IOError('Cannot read image: {}'.format(IMAGE_PATH))
img_height, img_width = img.shape[:2]

# Resize to the network input size and reverse the channel order
# (OpenCV loads images as BGR).
inp = cv.resize(img, INPUT_SIZE)
inp = inp[:, :, [2, 1, 0]]

with tf.Session() as sess:
    # as_default() returns a context manager; it only has an effect when
    # entered with a `with` statement.
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')

    num_detections, scores, boxes = sess.run(
        [sess.graph.get_tensor_by_name('num_detections:0'),
         sess.graph.get_tensor_by_name('detection_scores:0'),
         sess.graph.get_tensor_by_name('detection_boxes:0')],
        feed_dict={
            # Add the leading batch dimension expected by the feed.
            'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

# Only the first `count` entries of the (single-image) batch are used.
count = int(num_detections[0])
for score, box in zip(scores[0][:count], boxes[0][:count]):
    if float(score) <= SCORE_THRESHOLD:
        continue
    # Box values are fractions of the image size, indexed as
    # [top, left, bottom, right]; scale them to pixels of the original image.
    top, left, bottom, right = (float(v) for v in box)
    cv.rectangle(img,
                 (int(left * img_width), int(top * img_height)),
                 (int(right * img_width), int(bottom * img_height)),
                 BOX_COLOR, thickness=1)

cv.imshow('TensorFlow', img)
cv.waitKey()
# Close the preview window once a key has been pressed.
cv.destroyAllWindows()
- 将加载模型时使用的路径 `exported_model/frozen_inference_graph.pb` 换成自己训练并导出的模型文件;
- 将读取图片时使用的路径 `mouth_dataset/test/1_107.jpg` 换成自己需要检测的图片。