目录
2.2.5 构造bounding boxes并组合置信度得分、类型信息
1、写在前面
CenterNet是2019年的一篇论文Objects as Points中提出的网络。由于其结构简单、可扩展性强、模型转换方便的优点,至今仍很受欢迎。其GitHub地址为:https://github.com/xingyizhou/CenterNet。
整个模型包括三部分:backbone、上采样、heads。其中,backbone部分即为传统的分类模型去掉fc层,用于提取高级语义特征;上采样部分用于从低分辨率的feature map恢复部分分辨率,提高最终输出的feature map的分辨率;heads部分则得到三部分输出:每一类目标中心点的heatmap、包含每个目标尺寸(W、H)的wh、中心点偏移量offset。
对于最终输出的heatmap、wh、offset,其shape分别为:[B, C, H, W]、[B, 2, H, W]、[B, 2, H, W]。其中,B为batch大小,C为类别个数、H和W为feature map的高和宽。
我们今天要讲的后处理,即是对heatmap、wh、offset进行的处理,以得到最终bounding-boxes形式的输出,主要以源码中“src/lib/models/decode.py”的代码为讲解对象,介绍整个后处理的逻辑,并对其中的NMS、topK等过程进行逐行解释。
2、后处理源码解析
2.1 CenterNet推理过程
推理的完整流程为:
- 输入一张或者一个batch的图片,经过backbone(这里包括下采样和上采样)后输出feature map,尺寸为原图尺寸的1/4;
- 然后送入三个分支,分别得到heatmap、wh、offset;
- 最后通过后处理过程,得到bounding boxes。
而后处理仅在在推理阶段使用。
推理时,通过以下代码,可以得到hm、wh、reg,也即上文所说的heatmap、wh、offset:
def process(self, images, return_time=False):
with torch.no_grad():
output = self.model(images)[-1]
hm = output['hm'].sigmoid_()
wh = output['wh']
reg = output['reg'] if self.opt.reg_offset else None
然后,将这三者送入后处理流程:
dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)
2.2 后处理源码解析
ctdet_decode函数定义在“src/lib/models/decode.py”中:
def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
batch, cat, height, width = heat.size()
# 2.2.1 NMS
# 如果参数heat传进来之前没有进行sigmoid,就需要先使用下面的方法归一化heat中的值
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat = _nms(heat)
# 2.2.2 TopK
scores, inds, clses, ys, xs = _topk(heat, K=K)
# 2.2.3 提取reg和wh
if reg is not None:
reg = _transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, K, 2)
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
wh = _transpose_and_gather_feat(wh, inds)
# 2.2.4 是否对每一类分别设置wh
if cat_spec_wh:
wh = wh.view(batch, K, cat, 2)
clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
wh = wh.gather(2, clses_ind).view(batch, K, 2)
else:
wh = wh.view(batch, K, 2)
# 2.2.5 构造bounding boxes并组合置信度得分、类型信息
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
ys - wh[..., 1:2] / 2,
xs + wh[..., 0:1] / 2,
ys + wh[..., 1:2] / 2], dim=2)
detections = torch.cat([bboxes, scores, clses], dim=2)
return detections
下面,根据该函数的每个部分,分别进行解析。
2.2.1 NMS
从上述代码可以看出,首先需要对heat进行NMS处理,这里的NMS与传统的Anchor-based的检测算法不同,Anchor-based类算法的NMS是基于IOU进行的过滤,而CenterNet里面的NSM及其简单,仅仅是提取heatmap中的峰值,仅用一个3*3的maxpooling即可。NMS的代码如下,每一句我都添加了注释:
def _nms(heat, kernel=3):
# 设置padding值,使得经过maxpooling后尺寸不变
pad = (kernel - 1) // 2
# 利用maxpooling将峰值保留,而非峰值部分的信息被抹去了
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
# 将峰值的位置索引设为True,非峰值为False
keep = (hmax == heat).float()
# 返回结果中,峰值部分值保留,非峰值部分值为0
return heat * keep
2.2.2 Top K
经过NMS之后,紧接着就是topK操作了,这一步的目的是得到置信度排名前K个中心点的置信度得分、索引、类别、中心点坐标,代码如下,同样都添加了注释:
# 这里的scores是传入的heatmap
def _topk(scores, K=40):
batch, cat, height, width = scores.size()
# 首先将scores的H和W两个维度合并展开,然后利用torch.topk函数得到排序后的值及其索引,结果为:
# topk_scores size:(B,num_cls, K) float
# topk_inds size:(B,num_cls, K) int
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
# 得到中心点坐标(x,y), 其中inds = y * W + x, 现在知道inds,逆过程即为求x,y
topk_inds = topk_inds % (height * width) # 取%是为了使索引不至于超出范围
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
# 对每个峰值,确定其类别(当多个类的中心点重合时,只能保留一个置信度最大的)
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
# 输出的维度:
# topk_score size:(B, K) float
# topk_ind size:(B, K) int
# 得到top K个目标的类别,在多个类别时,topk_ind除以K可以将同类的点统一到一个值
topk_clses = (topk_ind / K).int()
# top K 个点的索引
topk_inds = _gather_feat(
topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
# top K个中心点坐标
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
2.2.3 提取reg和wh
接下来,就是提取topK个中心点对应的reg和wh。如果使用了reg也即中心点偏移,则先提取reg然后再将中心点坐标加上偏移量;如果没有使用reg,则直接将中心点坐标加上0.5的偏移。对reg和wh的topK提取主要使用了“_transpose_and_gather_feat”方法,其代码如下,同样逐行进行了注释说明:
def _gather_feat(feat, ind, mask=None):
# 第三个维度的值,如reg是中心点(x,y)的偏移量,为2
dim = feat.size(2)
# 将ind从[B, K]转换为[B, K, 1], 然后使用expand扩展为[B, K, dim]
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
# 从第二个维度,按照ind提供的的索引,提取对应的元素
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _transpose_and_gather_feat(feat, ind):
# 对输入feat先进行维度置换,将第二个维度换到最后面
feat = feat.permute(0, 2, 3, 1).contiguous()
# 将中间两个维度合并到一维,并后的维度其维数为feature map 的宽高乘积,
feat = feat.view(feat.size(0), -1, feat.size(3))
# 调用_gather_feat方法,从feat提取ind指定索引位置的元素
feat = _gather_feat(feat, ind)
return feat
2.2.4 是否对每一类分别设置wh
接下来的cat_spec_wh代表是否对每个类别分别设置了wh。如果是,则wh原来的维度其实是[B, C*2, H, W],经过_transpose_and_gather_feat处理之后是[B, K, C*2],所以就需要将最后一个维度展开,变为[B, K, C, 2],然后调用gather方法从C这个维度开始按照clses_ind提取topK个中心点对应的宽高信息;如果否,wh经过_transpose_and_gather_feat处理之后就直接是[B, K, 2]了,else后面的那句“wh = wh.view(batch, K, 2)”我个人感觉应该就可以不用了(对这一句不知道理解的对不对,有知道的小伙伴烦请留言告知~)。
2.2.5 构造bounding boxes并组合置信度得分、类型信息
最后一部分,就是先把clses、scores展开到[B, K, 1],然后利用中心点坐标(xs, ys)和宽高信息wh,组合成为bounding boxes,再把boxes、scores、clses拼接起来,作为函数返回值。
3、写在后面
至此,CenterNet的后处理源码就解析完了,主要结合的是官方源码的后处理部分,文章对每一步涉及到的操作都做了注解,相信看一遍就会明白;最好的建议,还是结合代码,使用一个示例运行一遍,看一看每一步运算后,tensor是如何变化的,我也是这样来一步步做的,也因此有了这篇解析。这篇文章的目的就是将后处理过程吃透,便于我们在推理过程的使用,同时作为一个备忘,用于我后续回看。那对于网络的forward部分,我们则可以将其转为ONNX、TensorRT等形式加速计算,然后将结果利用该后处理过程得到我们想要的边框、类别、置信度等信息。