Build an SSD-Based Object Detector in Just Three Steps
First, a look at the result
(My roommate wandered into the shot, haha.)
Step 1: Install the dependencies
Library | Version (the one I used) |
---|---|
tensorflow | 1.14.0 |
opencv | 3.4.2 |
numpy | 1.16.3 |
matplotlib | 3.0.3 |
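Installation:
In a terminal, run
pip install <library>==<version> --user
For example:
pip install tensorflow==1.14.0 --user
Note that OpenCV is published on PyPI as opencv-python, and its PyPI version strings have a fourth component, so the exact string differs slightly from the one in the table.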
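Notes:
You don't have to match my versions exactly, but if you hit an error such as AttributeError: module 'xxxx' has no attribute 'xxxxx', a version mismatch is the most likely cause. The code below relies on tf.contrib.slim, so it needs a TensorFlow 1.x release.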
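To see quickly whether your environment matches the table, you can print each library's version. The following is only a sketch and not part of the original tutorial: the helper names installed_version and compare_versions are made up for illustration, and it assumes each library exposes a __version__ attribute. OpenCV is looked up under its import name, cv2.

```python
import importlib

# Versions from the table above; keys are import names (OpenCV imports as cv2).
EXPECTED = {
    "tensorflow": "1.14.0",
    "cv2": "3.4.2",
    "numpy": "1.16.3",
    "matplotlib": "3.0.3",
}


def installed_version(module_name):
    """Return the module's __version__, or None if it cannot be imported
    or does not define __version__."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        # A broken install can raise errors other than ImportError.
        return None
    return getattr(module, "__version__", None)


def compare_versions(expected):
    """Map each module name to (installed, wanted, status)."""
    report = {}
    for name, wanted in expected.items():
        found = installed_version(name)
        if found is None:
            status = "missing"
        elif found == wanted:
            status = "ok"
        else:
            status = "differs"
        report[name] = (found, wanted, status)
    return report


if __name__ == "__main__":
    for name, (found, wanted, status) in compare_versions(EXPECTED).items():
        print("%-12s installed=%s expected=%s -> %s" % (name, found, wanted, status))
```

A "differs" line is not necessarily a problem; it only tells you where to look first if an AttributeError shows up later.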
Step 2: Download and extract the source code
Source repository:
https://github.com/balancap/SSD-Tensorflow
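Download the repository from that page as a ZIP archive and extract it.
Extract the model files:
Open the checkpoints folder inside the extracted repository.
Extract the archive you find there so that the checkpoint files sit directly in checkpoints; the script in Step 3 loads ./checkpoints/ssd_300_vgg.ckpt.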
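If you would rather extract the checkpoint from Python than from a file manager, a minimal sketch using the standard zipfile module could look like this. The archive name ssd_300_vgg.ckpt.zip is an assumption based on the checkpoint path used later in this post, so check what is actually in your checkpoints folder; extract_checkpoint is a hypothetical helper name.

```python
import os
import zipfile


def extract_checkpoint(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the sorted member names."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
        return sorted(archive.namelist())


if __name__ == "__main__":
    # Assumed archive name; adjust it to the file in your checkpoints folder.
    zip_path = os.path.join("checkpoints", "ssd_300_vgg.ckpt.zip")
    if os.path.exists(zip_path):
        for name in extract_checkpoint(zip_path, "checkpoints"):
            print("extracted:", name)
    else:
        print("archive not found:", zip_path)
```

Run it from the repository root so that the relative checkpoints path resolves.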
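Step 3: Copy the script that calls the model
Testing on an image:
For this step, create a file named demo_test.py in the root of the extracted repository and copy the following code into it: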
# demo_test.py
# Run from the root of the SSD-Tensorflow repository so these imports resolve.
import matplotlib.image as mpimg
import tensorflow as tf

from nets import ssd_vgg_300, np_methods
from notebooks import visualization
from preprocessing import ssd_vgg_preprocessing

slim = tf.contrib.slim

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)

# Input pipeline: a uint8 image of any size, warped to 300x300.
net_shape = (300, 300)
data_format = 'NHWC'
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input, None, None, net_shape, data_format,
    resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
image_4d = tf.expand_dims(image_pre, 0)

# Build the SSD-300 VGG network.
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions, localisations, _, _ = ssd_net.net(
        image_4d, is_training=False, reuse=reuse)

# Restore the pretrained weights extracted in Step 2.
ckpt_filename = './checkpoints/ssd_300_vgg.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)

# Anchor boxes for the 300x300 input.
ssd_anchors = ssd_net.anchors(net_shape)


def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
    # Run the network on one image.
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run(
        [image_4d, predictions, localisations, bbox_img],
        feed_dict={img_input: img})
    # Keep boxes whose score exceeds select_threshold and decode them.
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
        rpredictions, rlocalisations, ssd_anchors,
        select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    # Sort by score, then apply non-maximum suppression.
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(
        rclasses, rscores, rbboxes, top_k=400)
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(
        rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    # Map the boxes back to the original image frame.
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
    return rclasses, rscores, rbboxes


if __name__ == "__main__":
    # Guarded so that importing process_image from another script
    # does not also run this single-image test.
    image_path = './cat.jpg'  # path to the test image
    img = mpimg.imread(image_path)
    rclasses, rscores, rbboxes = process_image(img)  # pass the image in here
    # labeled_img = visualization.bboxes_draw_on_img(
    #     img, rclasses, rscores, rbboxes, visualization.colors_plasma)  # returns the annotated image
    visualization.plt_bboxes(img, rclasses, rscores, rbboxes)  # show the annotated image with matplotlib
Finally, set image_path in the script to the image you want to test, and you're done.
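process_image returns class ids, scores, and boxes rather than an annotated image, so it can help to print them in a readable form. The sketch below is an illustration under two assumptions: that the boxes are normalized (ymin, xmin, ymax, xmax) values in [0, 1], and that the 21-class checkpoint follows the usual Pascal VOC label order with 0 as background. describe_detections is a hypothetical helper, not a function from the repository.

```python
# Pascal VOC class names, indexed by class id (assumed order; 0 is background).
VOC_CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]


def describe_detections(classes, scores, bboxes, height, width):
    """Turn normalized (ymin, xmin, ymax, xmax) boxes into labelled pixel boxes."""
    results = []
    for cls, score, box in zip(classes, scores, bboxes):
        ymin, xmin, ymax, xmax = box
        results.append({
            "label": VOC_CLASSES[int(cls)],
            "score": float(score),
            # (left, top, right, bottom) in pixels
            "box": (int(xmin * width), int(ymin * height),
                    int(xmax * width), int(ymax * height)),
        })
    return results
```

With the variables from demo_test.py, a call would look like describe_detections(rclasses, rscores, rbboxes, img.shape[0], img.shape[1]).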
Testing on video:
Next, create a detector.py script:
# detector.py
import cv2

from demo_test import process_image
from notebooks import visualization


class Detector(object):
    def __init__(self, camera_index=0):
        self.camera_index = camera_index

    def Catch_Video(self, window_name='Detector'):
        cv2.namedWindow(window_name)
        cap = cv2.VideoCapture(self.camera_index)
        try:
            while cap.isOpened():
                catch, frame = cap.read()  # read one frame
                if not catch:
                    raise RuntimeError('Check if the camera is on.')
                # OpenCV delivers BGR frames, while demo_test.py fed the model
                # RGB images loaded with matplotlib, so convert before detection.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                rclasses, rscores, rbboxes = process_image(rgb)  # pass the frame in here
                labeled_img = visualization.bboxes_draw_on_img(
                    frame, rclasses, rscores, rbboxes, visualization.colors_plasma)
                cv2.imshow(window_name, labeled_img)
                c = cv2.waitKey(10)
                if c & 0xFF == ord('q'):
                    # press q to quit
                    break
                if cv2.getWindowProperty(window_name, cv2.WND_PROP_AUTOSIZE) < 1:
                    # click the window's x button to quit
                    break
        finally:
            # release the camera, even if an error was raised
            cap.release()
            cv2.destroyAllWindows()


if __name__ == "__main__":
    detect = Detector()
    detect.Catch_Video()
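Both scripts filter their detections through process_image, whose nms_threshold argument (0.45 by default) controls how aggressively overlapping boxes are merged. To make that parameter less opaque, here is a minimal pure-Python sketch of greedy, per-class non-maximum suppression. It illustrates the general technique and is not the repository's np_methods.bboxes_nms; the function names iou and greedy_nms are chosen here for illustration, and boxes are assumed to be (ymin, xmin, ymax, xmax).

```python
def iou(a, b):
    """Intersection over union of two (ymin, xmin, ymax, xmax) boxes."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(classes, scores, bboxes, nms_threshold=0.45):
    """Return the indices of the kept detections, highest score first.

    A detection is dropped when a higher-scoring detection of the same
    class overlaps it with IoU above nms_threshold.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        suppressed = any(
            classes[i] == classes[j] and iou(bboxes[i], bboxes[j]) > nms_threshold
            for j in keep
        )
        if not suppressed:
            keep.append(i)
    return keep
```

Lowering nms_threshold removes more overlapping boxes; raising it lets near-duplicates through. The select_threshold argument works earlier in the pipeline and simply discards low-score boxes.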
All done
Here are the results:
Image test result:
Video test result:
And that's it!
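For the theory, the paper, and a walkthrough of the source code, see my other two posts:
SSD Object Detection Explained (1): The Paper
SSD Object Detection Explained (2): The Code
If this helped, remember to like and follow.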