Faster Rcnn 一个目标被识别成多个label (重复框问题)

在使用faster-rcnn检测目标时,出现这种问题很正常,对于作者的数据集来说,主要是识别一些物体,所以出现下面这种图的结果很正常。

得到的结果还是很好的,毕竟模型是人家已经训练好的,我们只是拿来跑一下。可能你的结果一张照片只显示一个框,如果想一个图片里显示多个框,修改一下demo.py代码即可

但是当我们进行检测,有可能检测的是一个物体的状态,由于模型的拟合的不够好,一个物体的状态有可能会被检测成两种不同的状态,这样就需要我们对两种状态取一个最大值,然后保存。例如下图

该label标注的是normal,但在识别的时候出现了normal和tilt。只能说明模型的拟合程度不够好,并不是哪里出错了

解决方法其实很简单:

将所有的 box都保存起来,然后作nms,nms之后再进行可视化操作(就是画框框,上图相当于nms后把normal去掉了)

下面是我demo.py的代码。

在此之前,已经在原作者的基础上进行修改了,主要修改的方面有:使一张图中出现多个label 、 根据xml画出gt 、还有一个是计数方便数据统计

#!/usr/bin/env python
# ._*_.coding:utf-8
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import xml.dom.minidom as xmldom
import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms
from utils.timer import Timer
import tensorflow as tf
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

# Class tuple; index 0 must be the background class expected by Faster R-CNN.
# The remaining entries are this project's three object states.
CLASSES = ('__background__',
           'tilt', 'miss', 'normal')

# Checkpoint file name per backbone network.
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_100000.ckpt',)}
# Training-set name(s) for each --dataset choice.
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}


# Intersection-over-union of two axis-aligned boxes.
def compute_IOU(rec1, rec2):
    """Compute the intersection-over-union (IoU) of two rectangles.

    :param rec1: (x0, y0, x1, y1) — (x0, y0) is the top-left corner,
                 (x1, y1) the bottom-right corner. Same for ``rec2``.
    :param rec2: (x0, y0, x1, y1)
    :return: IoU as a float in [0.0, 1.0]; 0.0 when the boxes do not overlap.
    """
    left_column_max = max(rec1[0], rec2[0])
    right_column_min = min(rec1[2], rec2[2])
    up_row_max = max(rec1[1], rec2[1])
    down_row_min = min(rec1[3], rec2[3])
    # No intersection (edge-touching boxes count as non-overlapping, as before).
    if left_column_max >= right_column_min or down_row_min <= up_row_max:
        return 0.0  # BUG FIX: was int 0 — return a float for a consistent type
    # Overlapping case: |A ∩ B| / (|A| + |B| - |A ∩ B|).
    S1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
    S_cross = (down_row_min - up_row_max) * (right_column_min - left_column_max)
    return S_cross / (S1 + S2 - S_cross)


# Module-level counters and accumulators shared by vis_detections() and the
# __main__ loop. NOTE: ``global`` statements at module level are no-ops, so
# the original declarations were removed — these names are globals already.

# Detections that matched a same-class ground-truth box with IoU >= 0.5.
count = {
    "tilt": 0,
    "normal": 0,
    "miss": 0
}
# Ground-truth object counts read from the annotation XML files.
count1 = {
    "tilt": 0,
    "normal": 0,
    "miss": 0
}
# Running IoU sum of matched detections, and total detection wall time.
iou_sum = {
    "sum": 0.0,
    "time": 0.0
}
# Newline-separated image ids whose detection matched a ground-truth box of a
# *different* class (mis-classification log); '000' is the initial sentinel.
error = {
    "cuowu": '000'
}


def vis_detections(im, class_name, dets, image_name, ax, im_file, thresh=0.5):
    """Draw detected bounding boxes for one class and score them against GT.

    Draws the ground-truth boxes (yellow) parsed from the image's VOC XML
    annotation and every detection in *dets* scoring at least *thresh*
    (red). Each drawn detection is matched against every ground-truth box;
    a match with IoU >= 0.5 updates the module-level ``count``/``iou_sum``
    accumulators, and a >= 0.5 match against a *different* class is logged
    in ``error``.

    :param im: image array (not used here; kept for interface compatibility)
    :param class_name: predicted class label for all rows of *dets*
    :param dets: (N, 5) array — columns x0, y0, x1, y1, score
    :param image_name: image file name, e.g. '000001.jpg'
    :param ax: matplotlib axes to draw onto
    :param im_file: full image path (not used here; kept for compatibility)
    :param thresh: minimum score for a detection to be drawn/matched
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    # Parse the ground-truth annotation for this image.
    # NOTE(review): the annotation directory is hard-coded — consider making
    # it configurable alongside the demo directory used in __main__.
    inpath = '/home/share/liubo/tf-faster-rcnn1/data/VOCdevkit2007/VOC2007/Annotations/' + image_name[0:-4] + '.xml'
    domobj = xmldom.parse(inpath)
    elementobj = domobj.documentElement
    name = elementobj.getElementsByTagName("name")
    xmin = elementobj.getElementsByTagName("xmin")
    ymin = elementobj.getElementsByTagName("ymin")
    xmax = elementobj.getElementsByTagName("xmax")
    ymax = elementobj.getElementsByTagName("ymax")
    size = len(name)
    # BUG FIX: the original tested the module-level loop variable ``im_name``
    # here; use the ``image_name`` parameter so this function does not depend
    # on outer-loop state. (Debug print for one specific image.)
    if image_name == '300313.JPG':
        print(name[0].firstChild.data)

    # Draw every ground-truth box in yellow with its label.
    for i in range(size):
        gt_x0 = int(xmin[i].firstChild.data)
        gt_y0 = int(ymin[i].firstChild.data)
        gt_x1 = int(xmax[i].firstChild.data)
        gt_y1 = int(ymax[i].firstChild.data)
        ax.add_patch(plt.Rectangle((gt_x0, gt_y0),
                                   gt_x1 - gt_x0,
                                   gt_y1 - gt_y0, fill=False,
                                   edgecolor='yellow', linewidth=3.5)
                     )
        ax.text(gt_x1, gt_y0 - 2,
                '{:s}'.format(name[i].firstChild.data),
                bbox=dict(facecolor='white', alpha=0.5),
                fontsize=20, color='black')

    # Draw every confident detection in red and match it against the GT.
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        print(class_name, score)
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
        r1 = (bbox[0], bbox[1], bbox[2], bbox[3])
        for j in range(size):
            r2 = (int(xmin[j].firstChild.data), int(ymin[j].firstChild.data), int(xmax[j].firstChild.data),
                  int(ymax[j].firstChild.data))
            IOU = compute_IOU(r1, r2)
            # Same-class match: count it as a hit and accumulate its IoU.
            if IOU >= 0.5 and name[j].firstChild.data == class_name:
                count[class_name] += 1
                iou_sum['sum'] += IOU
            # Different-class match: record the image id as a mis-classification.
            if IOU >= 0.5 and name[j].firstChild.data != class_name:
                error['cuowu'] = error['cuowu'] + '\n' + image_name[0:-4]


# plt.axis('off')
#  plt.tight_layout()
#  plt.draw()


def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals.

    Runs the network on one demo image, applies per-class NMS plus a
    confidence filter, then a second, class-agnostic NMS across *all*
    surviving boxes — so an object detected under several labels keeps
    only its highest-scoring box — and finally draws the results.

    :param sess: TensorFlow session holding the restored model
    :param net: network object (vgg16 / resnetv1)
    :param image_name: file name inside ``cfg.DATA_DIR/demo``
    """
    # Load the demo image.
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds.
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    iou_sum['time'] += timer.total_time
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class.
    CONF_THRESH = 0.8
    # NMS threshold 0.0 suppresses *any* overlapping pair, keeping only the
    # top-scoring box of each overlapping cluster.
    NMS_THRESH = 0.0
    im = im[:, :, (2, 1, 0)]  # BGR (OpenCV) -> RGB (matplotlib)
    fig, ax = plt.subplots(figsize=(40, 40))
    ax.imshow(im, aspect='equal')
    # all_dets: (N, 5) rows of x0, y0, x1, y1, score
    # all_cls:  (N, 6) rows of x0, y0, x1, y1, score, class index
    all_dets = np.empty((0, 5), np.float32)
    all_cls = np.empty((0, 6))
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        # Keep only confident detections. CONSISTENCY FIX: the threshold was
        # hard-coded as 0.8 here; use the CONF_THRESH constant defined above.
        inds = np.where(dets[:, -1] >= CONF_THRESH)[0]
        dets = dets[inds, :]
        if len(dets) == 0:
            continue
        all_dets = np.append(all_dets, dets, axis=0)
        # Append the class index as a sixth column in one vectorized step
        # (replaces the original per-row vstack loop).
        cls_col = np.full((dets.shape[0], 1), cls_ind, dtype=all_cls.dtype)
        all_cls = np.vstack((all_cls, np.hstack((dets, cls_col))))

    # Class-agnostic NMS across all classes: when the same object was detected
    # under several labels, only the highest-scoring box survives.
    keep1 = nms(all_dets, NMS_THRESH)
    all_dets = all_dets[keep1, :]
    all_cls = all_cls.reshape(-1, 6)
    all_cls = all_cls[keep1, :]
    for i in np.arange(len(all_cls)):
        cls_index = int(all_cls[i][5])
        vis_detections(im, CLASSES[cls_index], all_dets[i].reshape(-1, 5), image_name, ax, im_file,
                       thresh=CONF_THRESH)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()


def parse_args():
    """Parse and return the command-line arguments for the demo script."""
    arg_parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    # Backbone network: keys of NETS select the checkpoint to restore.
    arg_parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                            choices=NETS.keys(), default='res101')
    # Dataset the model was trained on: keys of DATASETS.
    arg_parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                            choices=DATASETS.keys(), default='pascal_voc_0712')
    return arg_parser.parse_args()


if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    # NOTE(review): the checkpoint path always uses 'voc_2007_trainval'
    # regardless of the --dataset argument — confirm this is intentional.
    tfmodel = os.path.join('output', demonet, 'voc_2007_trainval', 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config: allow TF to fall back to CPU ops and grow GPU memory lazily
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network backbone chosen on the command line
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # 4 = number of classes including '__background__' (see CLASSES above).
    net.create_architecture("TEST", 4,
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #             '001763.jpg', '004545.jpg']
    # Run the demo on every image found in the (hard-coded) demo directory.
    im_names = os.listdir('/home/share/liubo/tf-faster-rcnn1/data/demo')
    cnt = 1
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ {:s}'.format(str(cnt)))
        cnt = cnt + 1
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
        # Tally the ground-truth objects of each class from the VOC XML
        # annotation so accuracy can be computed after the loop.
        inpath = '/home/share/liubo/tf-faster-rcnn1/data/VOCdevkit2007/VOC2007/Annotations/' + im_name[0:-4] + '.xml'
        domobj = xmldom.parse(inpath)
        elementobj = domobj.documentElement
        name = elementobj.getElementsByTagName("name")
        size = len(name)
        for i in range(size):
            count1[name[i].firstChild.data] += 1
        # Save the figure that demo()/vis_detections() drew for this image.
        plt.savefig("./data/demotest/" + im_name)

    # Summary: matched detections (count) vs. ground-truth totals (count1).
    print("测试结果:")
    for key, value in count.items():
        print(key + ":" + str(value))
    print("真实结果:")
    for key, value in count1.items():
        print(key + ":" + str(value))
    obj_sum = count['tilt'] + count['normal'] + count['miss'] + 0.0
    # NOTE(review): these divisions raise ZeroDivisionError when a class has
    # no ground-truth instances — consider guarding the denominators.
    print('tilt accuracy rate:' + format(count['tilt'] / count1['tilt'], '.3f'))
    print('miss accuracy rate:' + format(count['miss'] / count1['miss'], '.3f'))
    print('normal accuracy rate:' + format(count['normal'] / count1['normal'], '.3f'))
    print(
        'tilt+miss accuracy rate:' + format((count['tilt'] + count['miss']) / (count1['tilt'] + count1['miss']), '.3f'))
    print('total accuracy rate:' + format(
        (count['tilt'] + count['miss'] + count['normal']) / (count1['tilt'] + count1['miss'] + count1['normal']),
        '.3f'))
    print(error)
# plt.show()
# print("缺销照片数:%d" %count)

主要修改的是 函数 def demo(sess, net, image_name):

# 182行
all_dets = np.empty((0, 5), np.float32)
all_cls = np.empty((0, 6))

all_dets:所有的目标框,前4个是box坐标,最后一个是得分

all_cls:所有的目标框+类别,前4个是box坐标,第5个是得分,第6个是类别

# 190行
keep = nms(dets, NMS_THRESH)
dets = dets[keep, :]
inds = np.where(dets[:, -1] >= 0.8)[0]
dets = dets[inds, :]

这里就是先对该图片的cls类别进行nms,保留一些框,然后筛选得分>=0.8的保存起来

# 196行
all_dets = np.append(all_dets, dets, axis=0)
for i in dets:
   all_cls = np.vstack((all_cls, np.hstack((i, np.array([cls_ind])))))

前面就是正常的一个拼接,后面用for循环因为,dets有可能不止一个,直接拼接有啥方法我还不会(numpy掌握的不是很好),就是在每个det的后面加上一列类别(cls_ind),然后再和all_cls进行vstack拼接。

这里要说一下,为什么多加一个all_cls,而不是在all_dets的后面直接加上一列,因为在nms的时候用的是5列,如果这样修改,还需要修改nms的代码,比较麻烦,所以增加一个array

这样一个过程下来,就得到了所有的目标框(得分大于0.8),

# 205行    
keep1 = nms(all_dets, NMS_THRESH)
all_dets = all_dets[keep1, :]
all_cls = all_cls.reshape(-1, 6)
all_cls = all_cls[keep1, :]

然后对all_dets进行nms,去除重复框,这里就是筛选了,在同一个目标上得到多种不同的结果,只保留最大的得分所对应的框

这里要注意的是all_cls要进行reshape,为啥呢。。因为我测试出来的,这个数据哈有点怪,有的多维有的一维

# 209行
for i in np.arange(len(all_cls)):
    cls_index = int(all_cls[i][5])
    vis_detections(im, CLASSES[cls_index], all_dets[i].reshape(-1, 5), image_name, ax, im_file,thresh=CONF_THRESH)

这点就简单了,对每一个目标进行标注,传进去对应的参数就行

总的来说吧,搞了半个下午,思路很简单,就是将所有的保存起来,难就难在数组的操作,有的拼接不起来,需要测试是什么类型的,然后改数组类型,这个步骤比较繁琐,而且demo.py进行测试也慢。

我运行起来是没什么问题的,如果有问题欢迎一起来讨论。

这里感谢YJL同志给出的思路

猜你喜欢

转载自blog.csdn.net/qq_33193309/article/details/100105963