torchvision.ops.nms实现NMS

nms原理:
当目标检测模型对一个目标有多个检测框时,需要滤掉多余的框,留下最接近真实目标的框。

在这里插入图片描述

步骤是这样的:
1.先把目标框初筛一波,比如设阈值为0.25, 把预测概率 < 0.25的目标框滤掉。
2.把 每个类别的 目标框 按预测概率从大到小排序。
3.每个类别的目标框 两两计算 IOU,当IOU > 阈值时,说明这两个目标框高度重合,没必要都留着,留概率较大的那个。注意这里是同类别的框计算IOU,不同类别的即使高度重合也不滤掉。

这样每个类别就滤掉了多余的框。因为已经按照预测概率从大到小排序,所以两两计算IOU时优先计算的是概率较大的框。

下面看下torchvision.ops.nms的用法,
这里假设有N个目标框。
传入参数boxes为Tensor,shape为[N,4], 每个box坐标格式是(x1,y1,x2,y2),即左上角右下角坐标。
如果你的目标框为(x,y,w,h), 需要做格式转换。
score: Tensor, shape为[N], 每个目标框检测的概率,如果是COCO,预测了80个类别的概率,就把最大的概率取出来。
iou阈值:float型,每个类别内两两目标框IOU > 这个阈值时,扔掉概率小的那个。

用它之前先用概率阈值初筛一波box, 把预测概率很低的box去掉。
需要把box的坐标转为左上角右下角坐标格式。
把box按概率从大到小排序。
提取每个box的预测概率。

torchvision.ops.nms的注释:
```python
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4])): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """

下面以yolov8为例,说明如何用torchvision.ops.nms来计算nms, 过滤掉多余的目标框。

这里yolov8的prediction是(1,116,5460), 其中5460是anchor的数量,
116的前4个是目标框坐标(x,y,w,h), 中间80个是COCO数据集中80个类的预测概率,
后面32个是mask coeff, 分割mask用的,nms这里不用。只用到前面84个。

正常情况下每个类别都要算一次nms, 这里用了batched nms, 用了小技巧把所有类别的nms一起计算。

代码为了阅读简洁做了修改。

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
  #prediction:(1,116,5460) 
    
    bs = prediction.shape[0]  # batch size  #这里只有一张图片,所以是1
    nc = 80  # 80个类别
    nm = 32  #最后32是mask coeff,用于目标分割的
    mi = 84  # mask start index
    #4~84是80个类别概率所在的位置,每个anchor取出最大的预测概率,用0.25的阈值过滤一波。
    #最大的都滤掉的话,说明这个anchor处没有预测到目标,得到的是boolean(1,5460), anchor数量
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    #用来保存nms结果
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    
    for xi, x in enumerate(prediction):  # image index, image inference #每个图片计算一次
        #取出刚刚用阈值滤掉一波剩下的anchor,input x是(116,5460),滤掉后剩27个anchor
        #x transpose后变为(5460,116),过滤后为(27,116)
        x = x.transpose(0, -1)[xc[xi]]  # confidence
   
        # 把前面说的box坐标,80个类别的概率和mask分开
        #分别是(27,4), (27,80), (27,32)
        box, cls, mask = x.split((4, nc, nm), 1)
        #torchvision的nms要求是(x1,y1,x2,y2)格式
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        
        #cls:(27,80), 取每行最大值,得到(27,1). conf是最大value, j是最大value对应的index, 也就是class id
        conf, j = cls.max(1, keepdim=True)
        
        #上次已经用conf_thres筛过一次了,这次>conf_thres应该全是true
        #cat之后是(27,38), box:(27,4),conf:(27,1),j:(27,1), mask:(27,32)
        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        #把x的所有行按conf从大到小排序,(27,38)
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        #如果你不想每个类别都做一次nms,而是所有类别一起做nms
        #就需要把不同类别的目标框尽量没有重合,不至于把不同类别的IOU大的目标框滤掉
        #先用每个类别id乘一个很大的数,作为offset,把每个类别的box坐标都加上相应的offset,这是batched nms
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

        i = i[:max_det]  # limit detections
        
        output[xi] = x[i]  #取出NMS过滤剩下的prediction
        
    return output

猜你喜欢

转载自blog.csdn.net/level_code/article/details/131245680
NMS