Paper: Associative Embedding: End-to-End Learning for Joint Detection and Grouping
Code: princeton-vl/pose-ae-train
Abstract
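The paper proposes Associative Embedding, a new form of supervision for tasks that combine detection and grouping. Previous approaches to multi-person pose estimation, instance segmentation, and multi-object tracking were two-stage: detect first, then group. Associative embedding produces detection and grouping outputs at the same time, and it achieved state-of-the-art results for multi-person pose estimation on MPII and COCO.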
Introduction
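The basic idea of associative embedding is to give every detection a vector, called a tag, that indicates which group it belongs to. Detections that share the same tag are grouped into one set, which forms one person instance. For multi-person pose estimation with m keypoints per person, the network outputs m keypoint detection heatmaps and m tag maps; each tag map indicates, for one joint type, which person instance each pixel belongs to.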
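Decoding works as follows: extract the keypoint locations from the heatmaps, look up the tag values at those locations in the corresponding tag maps, and group keypoints with similar tags into one person instance.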
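During training, the loss only has to encourage the tags of different joints of the same person to be similar and the tags of different people to be different. The network only learns whether tags should be the same or different; it is free to choose the actual absolute values. (Tags 0 and 1 separate two people just as well as tags 2 and 3; the network decides which values to use, so there is no ground truth for the tag values themselves.)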
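The claim that only relative tag values matter can be checked numerically. Below is a minimal NumPy sketch of the pull/push grouping loss for 1-D tags, written to match the exp variant of the HigherHRNet `AELoss` code later in this post; the function name `ae_loss_1d` is hypothetical. Relabelling two people from tags 0/1 to 2/3, or shifting every tag by a constant, leaves both terms unchanged.

```python
import numpy as np


def ae_loss_1d(tags, person_ids):
    """Pull/push grouping loss for 1-D tags (hypothetical helper, exp variant).

    tags:       tag values read at ground-truth joint locations
    person_ids: integer id of the person each of those joints belongs to
    Returns (pull, push).
    """
    tags = np.asarray(tags, dtype=np.float64)
    person_ids = np.asarray(person_ids)
    ids = np.unique(person_ids)
    # reference embedding of each person: the mean of its joint tags
    refs = np.array([tags[person_ids == i].mean() for i in ids])
    # pull: squared distance of each joint tag to its person's reference,
    # averaged per person and then over people
    pull = np.mean([np.mean((tags[person_ids == i] - r) ** 2)
                    for i, r in zip(ids, refs)])
    n = len(refs)
    if n < 2:
        return pull, 0.0
    # push: exp(-d^2) over all ordered pairs of distinct people
    d = refs[:, None] - refs[None, :]
    push = (np.exp(-d ** 2).sum() - n) / (n * (n - 1)) * 0.5
    return pull, push


ids = np.array([0, 0, 1, 1])            # two people, two joints each
a = ae_loss_1d([0, 0, 1, 1], ids)       # people tagged 0 and 1
b = ae_loss_1d([2, 2, 3, 3], ids)       # same grouping, tagged 2 and 3
print(np.allclose(a, b))                # True

tags = np.array([0.1, 0.2, 3.0, 3.1])
print(np.allclose(ae_loss_1d(tags, ids), ae_loss_1d(tags + 10.0, ids)))  # True
```

Because both terms depend only on differences between tags, any labelling that preserves the grouping gets the same loss, which is why no ground-truth tag values are needed.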
Related Work
Omitted.
Approach
Network Architecture
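The network is a modified stacked Hourglass. For multi-person pose estimation, the number of feature channels is increased at each resolution level of the hourglass, and the residual modules inside the network are replaced by plain convolutions.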
Detection and Grouping
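Detection works as in single-person pose estimation: the network produces m heatmaps, one per keypoint type. The same joint type of different people is represented on the same heatmap without distinction, so an ideal heatmap has a separate peak for every person's instance of that joint. During training, the ground truth places a 2D Gaussian at each joint location, and the loss is the mean squared error (MSE).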
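A minimal sketch of the detection target and loss, assuming a Gaussian with `sigma=2.0` heatmap pixels and an element-wise max where the peaks of different people overlap (both are illustrative choices, not taken from the paper). `gaussian_heatmap` and `heatmap_mse` are hypothetical helper names.

```python
import numpy as np


def gaussian_heatmap(h, w, points, sigma=2.0):
    """Ground-truth heatmap for ONE joint type: one 2-D Gaussian peak per person.

    points: (x, y) locations of this joint for every visible person.
    Overlapping peaks are merged with an element-wise max (an assumption).
    """
    ys, xs = np.mgrid[0:h, 0:w]
    hm = np.zeros((h, w), dtype=np.float32)
    for x, y in points:
        g = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))
        hm = np.maximum(hm, g)
    return hm


def heatmap_mse(pred, gt):
    """Plain mean squared error between predicted and ground-truth heatmaps."""
    return float(np.mean((pred - gt) ** 2))


# two people, same joint type, drawn on the same 16x16 heatmap
gt = gaussian_heatmap(16, 16, [(3, 4), (11, 9)])
print(gt[4, 3], gt[9, 11])   # 1.0 1.0: one peak per person
print(heatmap_mse(gt, gt))   # 0.0
```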
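Grouping is the key contribution of the paper. Suppose detection has produced locations for all m keypoint types, with several candidates per type; these keypoints must be grouped into complete person instances. Alongside detection, the network produces an embedding that indicates which person instance each pixel belongs to (each joint type has its own embedding map). The paper argues that the embedding dimension is not critical: if a network can separate instances with a high-dimensional embedding, it should also be able to learn to project it down to a low-dimensional one, since the embedding only has to tell person instances apart. The paper therefore uses a 1-D embedding.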
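The tag indicates which person instance each detected joint belongs to. Every detection heatmap has its own tag map, so for m body joints there are m detection heatmaps for detection and m tag maps for grouping. To assign each detected joint to a person instance, the method first reads the tag value at the peak pixel of each detected keypoint from the corresponding tag map, then groups joints into people by comparing these tag values.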
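To make the lookup step concrete, here is a NumPy sketch that finds peaks on each joint heatmap and reads the tag at the same pixel of that joint's tag map. It stands in for the max-pool NMS plus `topk` used by `HeatmapParser` in the HigherHRNet code; the 3x3 neighbourhood, the default threshold of 0.1, and the function names are assumptions.

```python
import numpy as np


def local_maxima(hm, threshold):
    """(y, x) of pixels that are >= all 8 neighbours and above threshold."""
    h, w = hm.shape
    padded = np.pad(hm, 1, mode="constant", constant_values=-np.inf)
    # maximum over each 3x3 window, built from the 9 shifted views
    neigh = np.max(
        [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)],
        axis=0,
    )
    ys, xs = np.nonzero((hm == neigh) & (hm > threshold))
    return list(zip(ys.tolist(), xs.tolist()))


def detect_with_tags(heatmaps, tagmaps, threshold=0.1):
    """For each joint type j, list (y, x, score, tag) for every detected peak.

    heatmaps, tagmaps: float arrays of shape (m, H, W), one tag map per joint type.
    """
    out = []
    for j in range(heatmaps.shape[0]):
        peaks = local_maxima(heatmaps[j], threshold)
        out.append([(y, x, float(heatmaps[j, y, x]), float(tagmaps[j, y, x]))
                    for y, x in peaks])
    return out


hm = np.zeros((2, 8, 8))
tg = np.zeros((2, 8, 8))
hm[0, 1, 1] = hm[0, 6, 6] = 0.9   # joint 0 of person A and of person B
hm[1, 2, 1] = hm[1, 6, 5] = 0.8   # joint 1 of person A and of person B
tg[:, 4:, :] = 2.0                # person A's tags are ~0, person B's are ~2
res = detect_with_tags(hm, tg)
print(res)
# [[(1, 1, 0.9, 0.0), (6, 6, 0.9, 2.0)], [(2, 1, 0.8, 0.0), (6, 5, 0.8, 2.0)]]
```

Grouping then only needs the tag column: joints with tag near 0 form one person and joints with tag near 2 form the other.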
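The grouping loss measures how well the predicted tags agree with the ground-truth grouping. Specifically, for each person it reads the predicted tags from the tag maps at the ground-truth joint locations (not at the predicted joint locations).

The tags at a person's ground-truth joint locations are averaged to form that person's reference embedding. Writing $h_k(x)$ for the tag predicted at pixel $x$ in the tag map of joint $k$, $x_{nk}$ for the ground-truth location of joint $k$ of person $n$, and $K_n$ for the number of visible joints of person $n$:

$$\bar{h}_n = \frac{1}{K_n}\sum_{k} h_k(x_{nk})$$

The loss then follows the principle that tags of the same person should be close and reference embeddings of different people should differ. With $N$ people in the image, and in the form implemented by the HigherHRNet AELoss code in this post (exp variant), the pull and push terms are:

$$L_{pull} = \frac{1}{N}\sum_{n=1}^{N}\frac{1}{K_n}\sum_{k}\left(h_k(x_{nk}) - \bar{h}_n\right)^2, \qquad L_{push} = \frac{1}{2N(N-1)}\sum_{n \neq n'}\exp\left(-\left(\bar{h}_n - \bar{h}_{n'}\right)^2\right)$$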
Parsing Network Output
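To produce the final set of detections, we go through the joints one type at a time, starting from the head and torso and moving outward to the limbs. Suppose we start with the neck: its detections form the initial pool of person instances. For each subsequent joint type, we match its detections against the person pool. Each detection consists of its score and its tag, and each person's reference embedding is the mean of the tags of the joints already assigned to that person.

We compare the distances between these embeddings and greedily assign the highest-scoring detection that lies within the embedding-distance threshold. If a new joint does not match any person (its tag is not close to any reference embedding), it starts a new person instance. This continues until all joints have been assigned.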
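The greedy procedure can be sketched in plain Python with NumPy. This is a simplified illustration, assuming 1-D tags and a hypothetical `tag_threshold` of 1.0: candidates of each joint type are visited in descending score order, and each one claims the closest person (without that joint yet) whose reference embedding is within the threshold. The HigherHRNet code replaces this greedy step with KM (Hungarian) matching.

```python
import numpy as np


def greedy_group(detections, joint_order, tag_threshold=1.0):
    """Greedy joint-by-joint grouping (hypothetical sketch, 1-D tags).

    detections[j]: list of (score, tag, (x, y)) for joint type j.
    joint_order:   order in which joint types are visited (head/torso first).
    Returns a list of people, each a dict {joint_type: (score, tag, (x, y))}.
    """
    people = []
    for j in joint_order:
        # stronger responses get the first chance to claim a person
        for score, tag, loc in sorted(detections[j], key=lambda d: -d[0]):
            best, best_dist = None, tag_threshold
            for p_idx, person in enumerate(people):
                if j in person:  # this person already has joint j
                    continue
                # reference embedding: mean tag of the joints assigned so far
                ref = np.mean([t for _, t, _ in person.values()])
                dist = abs(tag - ref)
                if dist < best_dist:
                    best, best_dist = p_idx, dist
            if best is None:
                people.append({j: (score, tag, loc)})  # no close person: new one
            else:
                people[best][j] = (score, tag, loc)
    return people


dets = {
    0: [(0.9, 0.1, (10, 5)), (0.8, 2.0, (40, 6))],    # e.g. neck
    1: [(0.7, 1.9, (41, 20)), (0.95, 0.2, (11, 21))],  # next joint type
    2: [(0.6, 5.0, (70, 30))],                         # tag far from everyone
}
people = greedy_group(dets, joint_order=[0, 1, 2])
print(len(people))  # 3: two complete pairs plus one new person from joint 2
```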
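Implementation:
First, after NMS, at most max_num_people detections are taken from each joint heatmap, and the tag value at each detection is read out, giving: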
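```python
ans = {
    'tag_k': ...,  # (num_images, num_joints, max_num_people, tag_dim): tag at each candidate
    'loc_k': ...,  # (num_images, num_joints, max_num_people, 2): (x, y) of each candidate
    'val_k': ...,  # (num_images, num_joints, max_num_people): heatmap score of each candidate
}
```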
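A cost matrix is then built between the detections of the current joint type and the existing people, and the KM (Kuhn-Munkres, i.e. Hungarian) algorithm finds the optimal bipartite matching. Joint types are visited in a fixed order.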
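The cost matrix can be illustrated with a small NumPy sketch that mirrors `match_by_tag` with `use_detection_val` enabled: rounded tag distance times 100 minus the detection score, padded with 1e10 dummy columns when there are more detections than people. To stay dependency-free, an exhaustive search stands in for `munkres.Munkres`; `build_cost` and `brute_force_assignment` are hypothetical helpers, and the latter runs in factorial time, so it is only meant for tiny matrices.

```python
import itertools

import numpy as np


def build_cost(joint_tags, joint_scores, person_tags, big=1e10):
    """Cost matrix in the style of match_by_tag (1-D tags assumed).

    Rows: detections of the current joint type; columns: existing people.
    round(distance) * 100 dominates, so the detection score only breaks ties
    between candidates at the same rounded tag distance.
    """
    diff = np.abs(np.asarray(joint_tags)[:, None] - np.asarray(person_tags)[None, :])
    cost = np.round(diff) * 100 - np.asarray(joint_scores)[:, None]
    n_add, n_grp = cost.shape
    if n_add > n_grp:
        # dummy columns: a detection matched to one of these starts a new person
        cost = np.concatenate([cost, np.full((n_add, n_add - n_grp), big)], axis=1)
    return cost, diff


def brute_force_assignment(cost):
    """Minimum-cost assignment of every row to a distinct column (cols >= rows)."""
    rows, cols = cost.shape
    best, best_pairs = np.inf, None
    for perm in itertools.permutations(range(cols), rows):
        total = sum(cost[r, c] for r, c in enumerate(perm))
        if total < best:
            best, best_pairs = total, list(enumerate(perm))
    return best_pairs


cost, diff = build_cost([0.1, 2.1, 6.0], [0.9, 0.8, 0.7], [0.0, 2.0])
pairs = brute_force_assignment(cost)
tag_threshold = 1.0
for r, c in pairs:
    # a real column AND a tag distance under the threshold -> join that person
    if c < diff.shape[1] and diff[r, c] < tag_threshold:
        print(f"detection {r} -> person {c}")
    else:
        print(f"detection {r} -> new person")
```

Here detections 0 and 1 join persons 0 and 1, and detection 2 (tag 6.0) lands in a dummy column and starts a new person. As in `match_by_tag`, the threshold check uses the raw tag distance, not the score-adjusted cost.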
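Missing joints: at evaluation time every person instance must have a complete set of joints. For each missing joint, we consider the locations whose tags are close to the person's reference embedding and take the one with the highest heatmap response as that joint. This location may not have reached the detection threshold in the earlier detection stage.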
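The missing-joint step can be sketched as follows, assuming 1-D tags; `fill_missing_joint` is a hypothetical name. Like `refine` in the HigherHRNet code, it subtracts the rounded tag distance from the heatmap score, so a weak response with a matching tag beats a strong response that belongs to someone else.

```python
import numpy as np


def fill_missing_joint(heatmap, tagmap, person_tag):
    """Pick a location for a joint the person is missing (1-D tags assumed).

    heatmap, tagmap: (H, W) maps for this joint type.
    person_tag:      mean tag of the joints already assigned to the person.
    Returns (x, y, score); the score may be below the detection threshold.
    """
    dist = np.abs(tagmap - person_tag)
    score = heatmap - np.round(dist)  # penalise pixels whose tag is far away
    y, x = np.unravel_index(np.argmax(score), heatmap.shape)
    return int(x), int(y), float(heatmap[y, x])


hm = np.zeros((8, 8))
tg = np.full((8, 8), 3.0)
hm[2, 2] = 0.05
tg[2, 2] = 0.1   # weak response, tag close to this person
hm[6, 6] = 0.30  # stronger response, but its tag belongs to someone else
print(fill_missing_joint(hm, tg, person_tag=0.0))  # (2, 2, 0.05)
```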
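Multi-scale evaluation: the heatmaps from different scales are resized to a common resolution and averaged element-wise. The tag maps are also resized and then concatenated element-wise across the S scales, so each tag becomes an S-dimensional vector.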
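A sketch of the multi-scale aggregation, assuming that each scale's output can be brought to a common resolution by integer nearest-neighbour upsampling (`ndarray.repeat`); a real pipeline would use proper image resizing, and `upsample_nearest` and `aggregate_scales` are hypothetical names. Heatmaps are averaged; tag maps are stacked along a new last axis, so each pixel's tag becomes a vector with one entry per scale.

```python
import numpy as np


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of (C, H, W) maps by an integer factor."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


def aggregate_scales(heatmaps_per_scale, tags_per_scale, factors):
    """Average heatmaps over scales; stack tags into per-pixel vectors.

    heatmaps_per_scale[s], tags_per_scale[s]: (m, H_s, W_s) outputs at scale s.
    factors[s]: integer factor that brings scale s to the common resolution.
    Returns heatmaps of shape (m, H, W) and tags of shape (m, H, W, num_scales).
    """
    hms = [upsample_nearest(h, f) for h, f in zip(heatmaps_per_scale, factors)]
    tgs = [upsample_nearest(t, f) for t, f in zip(tags_per_scale, factors)]
    return np.mean(hms, axis=0), np.stack(tgs, axis=-1)


m = 17
full_hm, full_tag = np.ones((m, 8, 8)), np.zeros((m, 8, 8))
half_hm, half_tag = np.full((m, 4, 4), 0.5), np.ones((m, 4, 4))
hm, tag = aggregate_scales([full_hm, half_hm], [full_tag, half_tag], factors=[1, 2])
print(hm.shape, tag.shape)  # (17, 8, 8) (17, 8, 8, 2)
```

With tags stored as vectors, the grouping step compares them with an L2 norm, which is what the `np.linalg.norm(..., axis=2)` call in `match_by_tag` does.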
The associative embedding loss from the HigherHRNet code:
```python
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# Modified by Bowen Cheng ([email protected])
# ------------------------------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn


class AELoss(nn.Module):
    def __init__(self, loss_type):
        super().__init__()
        self.loss_type = loss_type  # 'exp' or 'max' form of the push term

    def singleTagLoss(self, pred_tag, joints):
        """
        associative embedding loss for one image

        pred_tag: (num_joints * H * W, 1) flattened tag maps (1-D tags)
        joints:   (max_num_people, num_joints, 2) array of (flat index, visible)
        """
        tags = []
        pull = 0
        for joints_per_person in joints:
            # predicted tags at this person's visible ground-truth joints
            tmp = []
            for joint in joints_per_person:
                if joint[1] > 0:
                    tmp.append(pred_tag[joint[0]])
            if len(tmp) == 0:
                continue
            tmp = torch.stack(tmp)
            # reference embedding: mean tag of the person
            tags.append(torch.mean(tmp, dim=0))
            # pull term: squared distance of each joint tag to the reference
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)

        num_tags = len(tags)
        # 0-dim zero on the same device/dtype as the predictions, so every
        # returned value has the same shape and torch.stack works in forward()
        zero = pred_tag.new_zeros(())
        if num_tags == 0:
            return zero, zero
        elif num_tags == 1:
            return zero, pull / num_tags

        tags = torch.stack(tags)  # (num_tags, 1)

        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B  # pairwise differences between reference embeddings

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            # remove the diagonal terms, each exp(0) = 1
            push = torch.sum(push) - num_tags
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unknown ae loss type')

        return push/((num_tags - 1) * num_tags) * 0.5, \
            pull/(num_tags)

    def forward(self, tags, joints):
        """
        accumulate the tag loss for each image in the batch
        """
        pushes, pulls = [], []
        joints = joints.cpu().data.numpy()
        batch_size = tags.size(0)
        for i in range(batch_size):
            push, pull = self.singleTagLoss(tags[i], joints[i])
            pushes.append(push)
            pulls.append(pull)
        return torch.stack(pushes), torch.stack(pulls)


def test_ae_loss():
    t = torch.tensor(
        np.arange(0, 32).reshape(1, 2, 4, 4).astype(np.float32)*0.1,
        requires_grad=True
    )
    t.register_hook(lambda x: print('t', x))

    ae_loss = AELoss(loss_type='exp')

    # joints[person, joint] = (flat index into the tag map, visible flag)
    joints = np.zeros((2, 2, 2), dtype=np.int64)
    joints[0, 0] = (3, 1)
    joints[1, 0] = (10, 1)
    joints[0, 1] = (22, 1)
    joints[1, 1] = (30, 1)
    joints = torch.from_numpy(joints)
    joints = joints.view(1, 2, 2, 2)

    t = t.contiguous().view(1, -1, 1)
    l = ae_loss(t, joints)

    print(l)


if __name__ == '__main__':
    test_ae_loss()
```
The grouping part of the HigherHRNet code:
```python
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Some code is from https://github.com/princeton-vl/pose-ae-train/blob/454d4ba113bbb9775d4dc259ef5e6c07c2ceed54/utils/group.py
# Written by Bin Xiao ([email protected])
# Modified by Bowen Cheng ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from munkres import Munkres
import numpy as np
import torch


def py_max_match(scores):
    # Kuhn-Munkres (Hungarian) algorithm: minimum-cost assignment
    m = Munkres()
    tmp = m.compute(scores)
    tmp = np.array(tmp).astype(np.int32)
    return tmp


def match_by_tag(inp, params):
    assert isinstance(params, Params), 'params should be class Params()'

    tag_k, loc_k, val_k = inp
    # one row per joint: (x, y, score, tag...)
    default_ = np.zeros((params.num_joints, 3 + tag_k.shape[2]))

    joint_dict = {}  # person key -> (num_joints, 3 + tag_dim) array
    tag_dict = {}    # person key -> list of tags assigned so far
    for i in range(params.num_joints):
        # visit joint types in a fixed order (head/torso first)
        idx = params.joint_order[i]

        tags = tag_k[idx]
        joints = np.concatenate(
            (loc_k[idx], val_k[idx, :, None], tags), 1
        )
        # keep only candidates above the detection threshold
        mask = joints[:, 2] > params.detection_threshold
        tags = tags[mask]
        joints = joints[mask]

        if joints.shape[0] == 0:
            continue

        if i == 0 or len(joint_dict) == 0:
            # every detection of the first joint type starts a new person
            for tag, joint in zip(tags, joints):
                key = tag[0]
                joint_dict.setdefault(key, np.copy(default_))[idx] = joint
                tag_dict[key] = [tag]
        else:
            grouped_keys = list(joint_dict.keys())[:params.max_num_people]
            # reference embedding of each person: mean of its tags so far
            grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]

            if params.ignore_too_much \
                    and len(grouped_keys) == params.max_num_people:
                continue

            diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
            diff_normed = np.linalg.norm(diff, ord=2, axis=2)
            diff_saved = np.copy(diff_normed)

            if params.use_detection_val:
                # rounded tag distance dominates; detection score breaks ties
                diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]

            num_added = diff.shape[0]
            num_grouped = diff.shape[1]

            if num_added > num_grouped:
                # pad with high-cost dummy columns so every detection is assigned
                diff_normed = np.concatenate(
                    (
                        diff_normed,
                        np.zeros((num_added, num_added-num_grouped))+1e10
                    ),
                    axis=1
                )

            pairs = py_max_match(diff_normed)
            for row, col in pairs:
                if (
                    row < num_added
                    and col < num_grouped
                    and diff_saved[row][col] < params.tag_threshold
                ):
                    # matched an existing person within the tag threshold
                    key = grouped_keys[col]
                    joint_dict[key][idx] = joints[row]
                    tag_dict[key].append(tags[row])
                else:
                    # unmatched or too far from every person: start a new one
                    key = tags[row][0]
                    joint_dict.setdefault(key, np.copy(default_))[idx] = \
                        joints[row]
                    tag_dict[key] = [tags[row]]

    ans = np.array([joint_dict[i] for i in joint_dict]).astype(np.float32)
    return ans


class Params(object):
    def __init__(self, cfg):
        self.num_joints = cfg.DATASET.NUM_JOINTS
        self.max_num_people = cfg.DATASET.MAX_NUM_PEOPLE

        self.detection_threshold = cfg.TEST.DETECTION_THRESHOLD
        self.tag_threshold = cfg.TEST.TAG_THRESHOLD
        self.use_detection_val = cfg.TEST.USE_DETECTION_VAL
        self.ignore_too_much = cfg.TEST.IGNORE_TOO_MUCH

        if cfg.DATASET.WITH_CENTER and cfg.TEST.IGNORE_CENTER:
            self.num_joints -= 1

        if cfg.DATASET.WITH_CENTER and not cfg.TEST.IGNORE_CENTER:
            self.joint_order = [
                i-1 for i in [18, 1, 2, 3, 4, 5, 6, 7, 12, 13, 8, 9, 10, 11, 14, 15, 16, 17]
            ]
        else:
            self.joint_order = [
                i-1 for i in [1, 2, 3, 4, 5, 6, 7, 12, 13, 8, 9, 10, 11, 14, 15, 16, 17]
            ]


class HeatmapParser(object):
    def __init__(self, cfg):
        self.params = Params(cfg)
        self.tag_per_joint = cfg.MODEL.TAG_PER_JOINT
        self.pool = torch.nn.MaxPool2d(
            cfg.TEST.NMS_KERNEL, 1, cfg.TEST.NMS_PADDING
        )

    def nms(self, det):
        # keep only pixels equal to the max of their pooling window
        maxm = self.pool(det)
        maxm = torch.eq(maxm, det).float()
        det = det * maxm
        return det

    def match(self, tag_k, loc_k, val_k):
        match = lambda x: match_by_tag(x, self.params)
        return list(map(match, zip(tag_k, loc_k, val_k)))

    def top_k(self, det, tag):
        det = self.nms(det)
        num_images = det.size(0)
        num_joints = det.size(1)
        h = det.size(2)
        w = det.size(3)
        det = det.view(num_images, num_joints, -1)
        # top max_num_people candidates on each joint heatmap
        val_k, ind = det.topk(self.params.max_num_people, dim=2)

        tag = tag.view(tag.size(0), tag.size(1), w*h, -1)
        if not self.tag_per_joint:
            tag = tag.expand(-1, self.params.num_joints, -1, -1)

        # read the tag value(s) at every candidate location
        tag_k = torch.stack(
            [
                torch.gather(tag[:, :, :, i], 2, ind)
                for i in range(tag.size(3))
            ],
            dim=3
        )

        # flat index -> (x, y); floor division keeps integer coordinates
        x = ind % w
        y = torch.div(ind, w, rounding_mode='floor')

        ind_k = torch.stack((x, y), dim=3)

        ans = {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }

        return ans

    def adjust(self, ans, det):
        # shift each joint a quarter pixel toward its higher neighbour
        for batch_id, people in enumerate(ans):
            for people_id, i in enumerate(people):
                for joint_id, joint in enumerate(i):
                    if joint[2] > 0:
                        x, y = joint[0:2]  # stored as (x, y)
                        xx, yy = int(x), int(y)
                        tmp = det[batch_id][joint_id]  # (H, W)
                        if tmp[yy, min(xx+1, tmp.shape[1]-1)] > tmp[yy, max(xx-1, 0)]:
                            x += 0.25
                        else:
                            x -= 0.25

                        if tmp[min(yy+1, tmp.shape[0]-1), xx] > tmp[max(0, yy-1), xx]:
                            y += 0.25
                        else:
                            y -= 0.25
                        ans[batch_id][people_id, joint_id, 0:2] = (x+0.5, y+0.5)
        return ans

    def refine(self, det, tag, keypoints):
        """
        Given initial keypoint predictions, we identify missing joints
        :param det: numpy.ndarray of size (17, 128, 128)
        :param tag: numpy.ndarray of size (17, 128, 128) if not flip
        :param keypoints: numpy.ndarray of size (17, 4) if not flip, last dim is (x, y, det score, tag score)
        :return:
        """
        if len(tag.shape) == 3:
            # tag shape: (17, 128, 128, 1)
            tag = tag[:, :, :, None]

        tags = []
        for i in range(keypoints.shape[0]):
            if keypoints[i, 2] > 0:
                # save tag value of detected keypoint
                x, y = keypoints[i][:2].astype(np.int32)
                tags.append(tag[i, y, x])

        # mean tag of current detected people
        prev_tag = np.mean(tags, axis=0)
        ans = []

        for i in range(keypoints.shape[0]):
            # score of joints i at all position
            tmp = det[i, :, :]
            # distance of all tag values with mean tag of current detected people
            tt = (((tag[i, :, :] - prev_tag[None, None, :]) ** 2).sum(axis=2) ** 0.5)
            tmp2 = tmp - np.round(tt)

            # find maximum position
            y, x = np.unravel_index(np.argmax(tmp2), tmp.shape)
            xx = x
            yy = y
            # detection score at maximum position
            val = tmp[y, x]
            # offset by 0.5
            x += 0.5
            y += 0.5

            # add a quarter offset
            if tmp[yy, min(xx + 1, tmp.shape[1] - 1)] > tmp[yy, max(xx - 1, 0)]:
                x += 0.25
            else:
                x -= 0.25

            if tmp[min(yy + 1, tmp.shape[0] - 1), xx] > tmp[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            ans.append((x, y, val))
        ans = np.array(ans)

        if ans is not None:
            for i in range(det.shape[0]):
                # add keypoint if it is not detected
                if ans[i, 2] > 0 and keypoints[i, 2] == 0:
                    keypoints[i, :2] = ans[i, :2]
                    keypoints[i, 2] = ans[i, 2]

        return keypoints

    def parse(self, det, tag, adjust=True, refine=True):
        ans = self.match(**self.top_k(det, tag))

        if adjust:
            ans = self.adjust(ans, det)

        scores = [i[:, 2].mean() for i in ans[0]]

        if refine:
            ans = ans[0]
            # for every detected person
            for i in range(len(ans)):
                det_numpy = det[0].cpu().numpy()
                tag_numpy = tag[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(
                        tag_numpy, (self.params.num_joints, 1, 1, 1)
                    )
                ans[i] = self.refine(det_numpy, tag_numpy, ans[i])
            ans = [ans]

        return ans, scores
```