本文参考自: 原文地址
非极大值抑制(NMS),其思想为:
对于重叠的候选框,若大于规定阈值(某一提前设定的置信度),则删除;低于阈值则保留。
对于无重叠的候选框,都保留。
IOU定义:两个边界框的交集部分除以它们的并集。
注:置信度定义(YOLO):置信度分数反应的是一个边界框是否包含对象并且预测对象的精确度
非极大值抑制的流程:
1.根据置信度的分进行排序
2.选择置信度最高的边界框添加到最终的输出列表中,并将其从边界框列表中删除
3.计算所有边界框的面积
4.计算置信度最高的边界框与其他便捷而框的IOU
5.删除大于阈值的边界框
6.重复上述过程,直到边界框列表为空
import cv2
import numpy as np
def nms(bounding_boxes,confidence_score,threshold):
#如果没有边界框,将返回空列表
if len(bounding_boxes)==0:
return [],[]
#边界框
boxes=np.array(bounding_boxes)
#边界框的坐标
start_x=boxes[:,0]
start_y=boxes[:,1]
end_x=boxes[:,2]
end_y=boxes[:,3]
#边界框的置信度
score=np.array(confidence_score)
#经过选择后的边界框
picked_boxes=[]
picked_score=[]
#计算边界框的面积
areas=(end_x-start_x+1)*(end_y-start_y+1)
#对边界框的置信度进行排序
order=np.argsort(score)#按升序排列
print("order:",order)
print("-1:",order[:-1])
#对边界框进行迭代
while order.size>0:
#最大置信度的索引
index=order[-1]#order最后一位,置信度最高,index=0
print("index:",index)
#选择置信度最大的边框
picked_boxes.append(bounding_boxes[index])
picked_score.append(confidence_score[index])
#计算IOU的坐标
x1=np.maximum(start_x[index],start_x[order[:-1]])
print("x1:",x1)
x2=np.minimum(end_x[index],end_x[order[:-1]])
print("x2:",x2)
y1=np.maximum(start_y[index],start_y[order[:-1]])
print("y1:",y1)
y2=np.minimum(end_y[index],end_y[order[:-1]])
print("y2:",y2)
#计算交叉区域的面积
w=np.maximum(0.0,x2-x1+1)
print("w:",w)
h=np.maximum(0.0,y2-y1+1)
print("h:",h)
intersection=w*h
print("intersection:",intersection)
#计算交叉区域和并集区域的比
ratio=intersection/(areas[index]+areas[order[:-1]]-intersection)
print("ratio:",ratio)
left=np.where(ratio<threshold)
print("left:",left)
order=order[left]
print("order:",order)
return picked_boxes,picked_score
#读入图像
image=cv2.imread("images/nms.jpg")
cv2.imshow("src",image)
#复制图像
org=image.copy()
#边界框大小
bounding_boxes=[(187,82,337,317),(150,67,305,282),(246,121,368,304)]
confidence_score=[0.9,0.75,0.8]
#画边界框的参数
font=cv2.FONT_HERSHEY_SIMPLEX
font_scale=1
thickness=2
#IOU的阈值
threshold=0.4
#画出边界框并标记置信度
for(start_x,start_y,end_x,end_y),confidence in zip(bounding_boxes,confidence_score):
(w,h),baseline=cv2.getTextSize(str(confidence),font,font_scale,thickness)
cv2.rectangle(org,(start_x,start_y-(2*baseline+5)),(start_x+w,start_y),(0,0,255),-1)
cv2.rectangle(org,(start_x,start_y),(end_x,end_y),(0,255,255),2)
cv2.putText(org,str(confidence),(start_x,start_y),font,font_scale,(0,0,0),thickness)
cv2.imshow("org",org)
#使用非极大值抑制算法
picked_boxes,picked_score=nms(bounding_boxes,confidence_score,threshold)
#画出经过nms算法处理过后的边界框和置信度
for (start_x,start_y,end_x,end_y),confidence in zip(picked_boxes,picked_score):
#获取字符串的宽度和高度
(w,h),baseline=cv2.getTextSize(str(confidence),font,font_scale,thickness)
cv2.rectangle(image,(start_x,start_y-(2*baseline+5)),(start_x+w,start_y),(0,0,255),-1)
cv2.rectangle(image,(start_x,start_y),(end_x,end_y),(0,255,255),2)
#在图像中显示文本字符串
cv2.putText(image,str(confidence),(start_x,start_y),font,font_scale,(0,0,0),thickness)
cv2.imshow("nms",image)
cv2.waitKey(0)