使用的代码来自 https://github.com/endernewton/tf-faster-rcnn ,
即 endernewton 用 TensorFlow 实现的 Faster R-CNN。
原来的 demo.py:每次检测一张图片,并把该图片每一类的检测结果分别单独显示。
修改之后:从 txt 文件中读取要检测的图片名称进行批量检测,把所有类别的检测结果画在同一张图上,然后保存到 data/result 目录中。
#!/usr/bin/env python
"""
https://blog.csdn.net/gusui7202/article/details/83239142
qhy。
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms
from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
# Detection labels. Index 0 is the background class; the per-class loop in
# demo() starts from index 1.
CLASSES = (
    '__background__',  # always index 0
    'normal bolt',
    'normal bolt-2',
    'normal bolt-3',
    'shim losing',
    'nut losing',
    'nut losing-2',
    'nut losing-3',
    'nut directly loosening',
    'nut directly loosening-2',
    'nut directly loosening-3',
    'nut directly loosening-4',
    'pin loosening',
    'pin closing',
    'visible pin losing',
    'visible pin losing-2',
    'invisible pin losing',
    'invisible pin losing-2',
)

# Checkpoint file name for each backbone; the keys are the --net choices.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
    'res101': ('res101_faster_rcnn_iter_40000.ckpt',),
}

# Training-set identifier for each --dataset choice.
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
}
# NOTE(review): this opens work.txt in write mode at import time, so the file is
# created or truncated whenever this module is loaded. Nothing visible in this
# file writes to `fi2` or closes it. It looks like a leftover from an earlier
# logging step; confirm it is unused elsewhere, then remove it.
fi2=open('/home/omnisky/q/tf-faster-rcnn-master/work.txt','w')
def vis_detections(image_name, im, class_name, dets, thresh=0.5):
    """Plot the detections of a single class on a new matplotlib figure.

    Legacy helper: demo() in this file does not call it, and it is kept only
    for reference. The figure created here is neither saved nor shown, and
    ``image_name`` is accepted but not used.

    Args:
        image_name: unused; kept for call compatibility.
        im: image array. Its last axis is reordered with (2, 1, 0) before
            display, presumably converting cv2's BGR order to RGB (confirm).
        class_name: label drawn next to each box and used in the title.
        dets: 2-D array. Columns 0-3 are the box corners (x1, y1, x2, y2)
            and the last column is the score.
        thresh: minimum score for a row to be drawn.

    Returns:
        None. If no row reaches ``thresh`` it returns immediately, before
        touching ``im`` or creating a figure.
    """
    passing = np.where(dets[:, -1] >= thresh)[0]
    if passing.size == 0:
        return

    _fig, axes = plt.subplots(figsize=(12, 12))
    axes.imshow(im[:, :, (2, 1, 0)], aspect='equal')
    for row in dets[passing]:
        x_min, y_min, x_max, y_max = row[:4]
        axes.add_patch(
            plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                          fill=False, edgecolor='red', linewidth=1.5))
        # Label sits 2 px above the top-left corner of the box.
        axes.text(x_min, y_min - 2,
                  '{:s} {:.3f}'.format(class_name, row[-1]),
                  bbox={'facecolor': 'blue', 'alpha': 0.5},
                  fontsize=14, color='white')
    axes.set_title(
        '{0} detections with p({0} | box) >= {1:.1f}'.format(class_name, thresh),
        fontsize=14)
def _draw_class_boxes(ax, class_name, dets, thresh):
    """Draw each row of ``dets`` whose score is >= ``thresh`` as a labelled box.

    ``dets`` rows are ``[x1, y1, x2, y2, score]``, as built in demo().
    """
    for i in np.where(dets[:, -1] >= thresh)[0]:
        bbox = dets[i, :4]
        score = dets[i, -1]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=1.5))
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')


def demo(image_name, sess, net, conf_thresh=0.7, nms_thresh=0.3,
         result_dir='/home/omnisky/q/tf-faster-rcnn-master/data/result',
         close_fig=True):
    """Detect every class in one image and save all boxes on a single PNG.

    The image is read from ``<cfg.DATA_DIR>/demo/<image_name>``. For each
    non-background entry of CLASSES, the function:

    * applies per-class NMS,
    * draws the boxes that score >= ``conf_thresh`` on one shared figure,
    * saves that figure as ``<result_dir>/<image stem>.png``.

    Args:
        image_name: file name of the image inside ``cfg.DATA_DIR/demo``.
        sess: TensorFlow session holding the restored weights.
        net: network object passed through to ``im_detect``.
        conf_thresh: minimum score for a box to be drawn (default 0.7).
        nms_thresh: overlap threshold passed to ``nms`` (default 0.3).
        result_dir: directory the annotated PNG is written to. It is created
            if it does not exist.
        close_fig: close the figure after saving, so memory does not grow when
            many images are processed in one run. Pass False to keep the
            figure open, e.g. for a later ``plt.show()``.

    Raises:
        IOError: if ``cv2.imread`` cannot read the image file.
    """
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread returns None instead of raising for missing or unreadable
        # files; without this check the failure surfaces later as an obscure
        # error inside im_detect.
        raise IOError('Could not read image {:s}'.format(im_file))

    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    fig, ax = plt.subplots(figsize=(12, 12))
    # Reverse the channel axis (BGR from OpenCV -> RGB for matplotlib).
    ax.imshow(im[:, :, (2, 1, 0)], aspect='equal', alpha=0.5)

    # Column block 0 belongs to the background class, so start at index 1.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, nms_thresh)
        _draw_class_boxes(ax, cls, dets[keep, :], conf_thresh)

    ax.axis('off')
    fig.tight_layout()

    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    # Swap only the extension. str.replace('jpg', 'png') would also rewrite
    # "jpg" appearing elsewhere in the name.
    png_name = os.path.splitext(image_name)[0] + '.png'
    out_path = os.path.join(result_dir, png_name)
    fig.savefig(out_path)
    print("save image to {}".format(out_path))

    if close_fig:
        # Each 12x12-inch figure keeps its canvas alive until closed; in a batch
        # run unclosed figures accumulate and exhaust memory.
        plt.close(fig)
def parse_args():
    """Parse the command line for this demo and return the namespace.

    Options:
        --net      stored as ``demo_net``; one of the NETS keys (default 'res101').
        --dataset  stored as ``dataset``; one of the DATASETS keys
                   (default 'pascal_voc_0712').
    """
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    option_table = (
        ('--net', dict(dest='demo_net',
                       help='Network to use [vgg16 res101]',
                       choices=NETS.keys(),
                       default='res101')),
        ('--dataset', dict(dest='dataset',
                           help='Trained dataset [pascal_voc pascal_voc_0712]',
                           choices=DATASETS.keys(),
                           default='pascal_voc_0712')),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _read_image_names(list_file, ext='.jpg'):
    """Return image file names from ``list_file``, one image id per line.

    Surrounding whitespace, including the carriage return left by Windows line
    endings, is stripped from each line. Blank lines are skipped instead of
    turning into a bare ``ext`` name. ``ext`` is appended to every id.
    """
    with open(list_file) as f:
        return [line.strip() + ext for line in f if line.strip()]


def main():
    """Restore the trained model and run demo() on every image id in test.txt."""
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # Model path. The backbone directory and checkpoint name now follow --net;
    # previously they were hard-coded to res101, so --net vgg16 built a vgg16
    # graph and tried to restore res101 weights. The training-set directory
    # stays fixed to voc_2007_trainval, so --dataset is parsed but not used
    # here (its default maps to a different identifier).
    demonet = args.demo_net
    tfmodel = os.path.join('/home/omnisky/q/tf-faster-rcnn-master/output',
                           demonet, 'voc_2007_trainval', 'default',
                           NETS[demonet][0])
    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # Session config: allow CPU fallback and grow GPU memory on demand.
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    sess = tf.Session(config=tfconfig)

    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # The class count must match CLASSES (background included).
    net.create_architecture("TEST", len(CLASSES),
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)
    print('Loaded network {:s}'.format(tfmodel))

    # Read the image ids from the txt file and detect them one by one.
    im_names = _read_image_names(
        '/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt')
    print(im_names)
    try:
        for im_name in im_names:
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            print('Demo for data/demo/{}'.format(im_name))
            demo(im_name, sess, net)
    finally:
        sess.close()

    # Displays any figures that are still open. If demo() leaves one figure
    # open per image, this shows them all at once; the original author warned
    # that this can hang the machine for large batches.
    plt.show()


if __name__ == '__main__':
    main()