Faster R-CNN batch testing, with the detection results of all classes drawn on a single image.

Based on https://github.com/endernewton/tf-faster-rcnn

endernewton's TensorFlow implementation of Faster R-CNN.

The original demo.py detects a single image and shows the detection results of each class in a separate figure.

After the modification: the names of the images to test are read from a txt file, detection runs as a batch, the detections of all classes are drawn onto one image, and the result is saved under data/result.
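
For reference, the test.txt that the script reads near the end is the standard VOC image-set list: one image ID per line, without the file extension. An illustrative example (these IDs are made up):

000001
000002
000015

The script appends '.jpg' to every line, and demo() then reads the image from os.path.join(cfg.DATA_DIR, 'demo', image_name), so the test images themselves still have to be placed under data/demo/.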


 

#!/usr/bin/env python

"""
https://blog.csdn.net/gusui7202/article/details/83239142
qhy.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse


from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',  # always index 0
                     'normal bolt', 'normal bolt-2', 'normal bolt-3',
                     'shim losing', 'nut losing', 'nut losing-2', 'nut losing-3',
                     'nut directly loosening', 'nut directly loosening-2',
                     'nut directly loosening-3', 'nut directly loosening-4',
                     'pin loosening', 'pin closing',
                     'visible pin losing', 'visible pin losing-2',
                     'invisible pin losing', 'invisible pin losing-2')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_40000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
fi2 = open('/home/omnisky/q/tf-faster-rcnn-master/work.txt', 'w')  # opened here but not used anywhere else in this script
def vis_detections(image_name, im, class_name, dets, thresh=0.5):  # no longer used; demo() below draws the boxes inline
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=1.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
#    plt.axis('off')
#    plt.tight_layout()
#    plt.draw()
#    image_name=image_name.replace('jpg','png')
#    plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name)
#    print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))

def demo(image_name, sess, net):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.7
    thresh = CONF_THRESH  # score threshold for keeping a detection in the figure
    NMS_THRESH = 0.3
# Open the image: convert OpenCV's BGR to RGB and create one figure that will hold all detections
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal', alpha=0.5)
# For every class, run per-class NMS and draw a box on the figure for each detection above the threshold
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
#        vis_detections(image_name, im, cls, dets, thresh=CONF_THRESH)  # commented out; the inline drawing code below is used instead

        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
    
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=1.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(cls, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    image_name = image_name.replace('jpg', 'png')  # change the extension so the figure is saved as PNG
    plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name)
    print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))
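    # Optional: calling plt.close(fig) here keeps matplotlib from holding one open
    # figure per image in memory during a long batch run.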

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='res101')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path (hardcoded; the --net / --dataset arguments above are not used to build it)
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = ('/home/omnisky/q/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default/'
               'res101_faster_rcnn_iter_40000.ckpt')


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 18,  # 18 = 17 object classes + background, matching CLASSES above
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))
# Read the image list from the txt file and run detection image by image.
    fi=open('/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt')

    txt=fi.readlines()
    im_names = []
    for line in txt:
        line=line.strip('\n')
        line=line.replace('\r','')
        line=(line+'.jpg')
        im_names.append(line)
    print(im_names)
    fi.close()
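    # demo() saves each figure to data/result/, so that directory must exist before
    # the loop starts; if needed, create it first, e.g.:
    #     if not os.path.exists('/home/omnisky/q/tf-faster-rcnn-master/data/result'):
    #         os.makedirs('/home/omnisky/q/tf-faster-rcnn-master/data/result')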
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(im_name, sess, net)
     
    plt.show()  # Better to comment this out: showing every figure at once can easily freeze the machine.
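
The modified file can be saved over tools/demo.py or next to it under another name (the file name is up to you); after adjusting the hardcoded paths above to your own installation, it is launched the same way as the original demo, for example:

python tools/demo.py --net res101 --dataset pascal_voc

One PNG per input image is then written to data/result/, with the boxes of every detected class drawn on the same figure.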


Reposted from blog.csdn.net/gusui7202/article/details/83240212