Faster R-CNN 批量测试并保存结果 / 所有类别在同一张图中显示 / 同时显示 ground truth

原来的 demo 只处理单张图片，并且按类别分别显示检测框；修改后可以批量测试并保存结果，所有类别的检测结果绘制在同一张图片上，同时将 ground truth（真实标注）一起绘制出来。

#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick

#qhy
#2018.10.26
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import xml.etree.ElementTree as ET
import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse


from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

# Class names in score-column order: demo() reads the scores/boxes of class i
# from column i of the detector output, so this order must match training.
# Entry 0 is the background class and is never drawn.
CLASSES = (
    '__background__',  # always index 0
    'normal bolt',
    'normal bolt-2',
    'normal bolt-3',
    'nut losing',
    'nut losing-2',
    'nut directly loosening',
    'pin closing',
    'visible pin losing',
    'visible pin losing-2',
    'invisible pin losing',
    'invisible pin losing-2',
)

# Checkpoint file name for each backbone, keyed by the --net choice.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
    'res101': ('res101_faster_rcnn_iter_100000.ckpt',),
}
# Training-set name for each --dataset choice.
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
}
def vis_detections(image_name, im, class_name, dets, thresh=0.5):
    """Plot the detections of a single class on a new figure.

    Args:
        image_name: accepted for call compatibility; not used.
        im: image array whose last axis is reversed before display
            (e.g. OpenCV BGR -> RGB).
        class_name: label printed next to every box and in the title.
        dets: array whose rows are ``[x1, y1, x2, y2, ..., score]``; the
            score is the last column.
        thresh: only rows with ``score >= thresh`` are drawn.

    Returns None. When no row passes ``thresh`` nothing is plotted and no
    figure is created. The figure is neither shown, saved nor closed here;
    that is left to the caller.
    """
    keep = np.where(dets[:, -1] >= thresh)[0]
    if keep.size == 0:
        return

    rgb = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(rgb, aspect='equal')

    for row in dets[keep]:
        x1, y1, x2, y2 = row[:4]
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                          fill=False, edgecolor='red', linewidth=1.5)
        )
        ax.text(x1, y1 - 2,
                '{:s} {:.3f}'.format(class_name, row[-1]),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title('{0} detections with p({0} | box) >= {1:.1f}'.format(class_name, thresh),
                 fontsize=14)

def _draw_ground_truth(ax, xml_file):
    """Draw every ``<object>`` of a Pascal-VOC style annotation onto ``ax``.

    Each object becomes a green rectangle with its ``<name>`` on a yellow
    label. A missing annotation file is reported and skipped, so that one
    unlabeled image does not abort a whole batch run.
    """
    if not os.path.isfile(xml_file):
        print('No annotation found at {}; skipping ground truth'.format(xml_file))
        return
    root = ET.parse(xml_file).getroot()
    for obj in root.findall('object'):
        name = obj.find('name').text
        bndbox = obj.find('bndbox')
        # float() accepts both integer ("12") and fractional ("12.5")
        # coordinates; int() would reject the latter.
        xmin, ymin, xmax, ymax = [float(bndbox.find(tag).text)
                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        ax.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                          fill=False, edgecolor='g', linewidth=2.5)
        )
        ax.text(xmin, ymin - 5, '{:s}'.format(name),
                bbox=dict(facecolor='yellow', alpha=0.5),
                fontsize=14, color='black')


def _draw_detections(ax, scores, boxes, conf_thresh, nms_thresh):
    """Draw, for each foreground class, its NMS-filtered detections >= conf_thresh.

    ``scores[:, c]`` and ``boxes[:, 4*c:4*(c+1)]`` hold the score and box of
    class ``CLASSES[c]`` for every proposal. Kept detections are drawn as red
    rectangles with a blue "<class> <score>" label.
    """
    # Start at 1: index 0 of CLASSES is the background class.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        dets = dets[nms(dets, nms_thresh), :]
        for x1, y1, x2, y2, score in dets[dets[:, -1] >= conf_thresh]:
            ax.add_patch(
                plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                              fill=False, edgecolor='red', linewidth=2.5)
            )
            ax.text(x1, y1 - 2,
                    '{:s} {:.3f}'.format(cls, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')


def demo(image_name, sess, net, conf_thresh=0.7, nms_thresh=0.3,
         xml_dir='/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/Annotations/',
         result_dir='/home/omnisky/q/tf-faster-rcnn-master/data/result/'):
    """Detect all classes in one image and save a figure with detections and ground truth.

    The image is read from ``cfg.DATA_DIR/demo/<image_name>``. Its annotation
    ``<xml_dir>/<stem>.xml`` is drawn as green boxes with yellow labels. For
    every foreground class, detections that survive per-class NMS and score at
    least ``conf_thresh`` are drawn as red boxes with blue labels. The figure
    is saved to ``<result_dir>/<stem>.png`` and then closed.

    Args:
        image_name: image file name relative to ``cfg.DATA_DIR/demo``.
        sess: TensorFlow session holding the restored weights.
        net: network object passed through to ``im_detect``.
        conf_thresh: minimum class score for a detection to be drawn.
        nms_thresh: overlap threshold passed to ``nms`` for each class.
        xml_dir: directory containing the annotation XML files.
        result_dir: directory the PNG figures are written to; created if missing.

    Raises:
        IOError: if the image file cannot be read.
    """
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Could not read image {:s}'.format(im_file))

    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    # Swap only the extension; str.replace('jpg', ...) would also rewrite
    # any "jpg" occurring elsewhere in the name.
    stem = os.path.splitext(image_name)[0]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        # Reverse the channel axis (OpenCV loads BGR, matplotlib shows RGB).
        ax.imshow(im[:, :, (2, 1, 0)], aspect='equal', alpha=0.75)
        _draw_ground_truth(ax, os.path.join(xml_dir, stem + '.xml'))
        _draw_detections(ax, scores, boxes, conf_thresh, nms_thresh)
        ax.axis('off')
        fig.tight_layout()
        if not os.path.isdir(result_dir):
            os.makedirs(result_dir)
        save_path = os.path.join(result_dir, stem + '.png')
        fig.savefig(save_path)
        print('save image to {}'.format(save_path))
    finally:
        # demo() runs once per image in batch mode; without an explicit close
        # pyplot keeps every figure alive, leaking memory across the run.
        plt.close(fig)

def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Returns:
        argparse.Namespace with ``demo_net`` (a key of NETS, default
        'res101') and ``dataset`` (a key of DATASETS, default
        'pascal_voc_0712').
    """
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    option_specs = (
        ('--net', dict(dest='demo_net',
                       help='Network to use [vgg16 res101]',
                       choices=NETS.keys(),
                       default='res101')),
        ('--dataset', dict(dest='dataset',
                           help='Trained dataset [pascal_voc pascal_voc_0712]',
                           choices=DATASETS.keys(),
                           default='pascal_voc_0712')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = ('/home/omnisky/q/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default/res101_faster_rcnn_iter_100000.ckpt')


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 12,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))
    fi=open('/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt')

    txt=fi.readlines()
    im_names = []
    for line in txt:
        line=line.strip('\n')
        line=line.replace('\r','')
        line=(line+'.jpg')
        im_names.append(line)
    print(im_names)
    fi.close()
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(im_name, sess, net)
     
#    plt.show()

绿色框为真实标签（ground truth），其类别名称显示在黄色背景的文字框中；红色框为检测结果，其类别和置信度显示在蓝色背景的文字框中。

猜你喜欢

转载自blog.csdn.net/gusui7202/article/details/83412943