PointFeatureEncoder模块解析
PointFeatureEncoder在forward函数时的data_dict与初始化参数如下所示:
这里使用的点云特征就是xyz+反射强度,具体作用是旋转是否使用xyz位置特征作为点云特征,否则在kitti数据集中就只有反射强度这一维的特征。不过在kitti数据集的处理流程中,这一步其实没有带来变换,唯一的变化是在data_dict中增加了data_dict['use_lead_xyz'] = True
这一步。
整个类的实现代码如下:
import numpy as np
# 点云特征编码的基类
class PointFeatureEncoder(object):
def __init__(self, config, point_cloud_range=None):
super().__init__()
self.point_encoding_config = config
assert list(self.point_encoding_config.src_feature_list[0:3]) == ['x', 'y', 'z']
# 在pointpillars中used和src使用的特征都是4种:位置信息xyz + 点强度信息intensity
self.used_feature_list = self.point_encoding_config.used_feature_list
self.src_feature_list = self.point_encoding_config.src_feature_list
self.point_cloud_range = point_cloud_range
@property
def num_point_features(self):
return getattr(self, self.point_encoding_config.encoding_type)(points=None)
def forward(self, data_dict):
"""
Args:
data_dict:
points: (N, 3 + C_in)
...
Returns:
data_dict:
points: (N, 3 + C_out),
use_lead_xyz: whether to use xyz as point-wise features
...
"""
data_dict['points'], use_lead_xyz = getattr(self, self.point_encoding_config.encoding_type)( # (N, 4) , True
data_dict['points']
)
data_dict['use_lead_xyz'] = use_lead_xyz # True
if self.point_encoding_config.get('filter_sweeps', False) and 'timestamp' in self.src_feature_list: # False
max_sweeps = self.point_encoding_config.max_sweeps
idx = self.src_feature_list.index('timestamp')
dt = np.round(data_dict['points'][:, idx], 2)
max_dt = sorted(np.unique(dt))[min(len(np.unique(dt))-1, max_sweeps-1)]
data_dict['points'] = data_dict['points'][dt <= max_dt]
return data_dict
def absolute_coordinates_encoding(self, points=None):
if points is None:
num_output_features = len(self.used_feature_list)
return num_output_features
assert points.shape[-1] == len(self.src_feature_list)
point_feature_list = [points[:, 0:3]] # 提取xyz坐标构建为一个列表元素,剩余的特征比如反射强度再同样作为一个列表元素
for x in self.used_feature_list:
if x in ['x', 'y', 'z']:
continue
idx = self.src_feature_list.index(x) # 3
point_feature_list.append(points[:, idx:idx+1]) # 索引特征
point_features = np.concatenate(point_feature_list, axis=1) # 特征进行拼接
return point_features, True