[Autonomous Driving] Monocular 3D Detection: An M3D-RPN Walkthrough and Paddle Reproduction

1. Introduction

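The authors propose a single end-to-end region proposal network for multi-class 3D object detection. 2D and 3D detection both ultimately aim to classify every instance of an object; they differ in the dimensionality of the localization target. Intuitively, we would expect the strength of 2D detection to guide and improve 3D detection, ideally within one unified framework rather than as separate components. The authors therefore reformulate the 3D detection problem so that the 2D and 3D spaces share anchors and classification targets. As a result, in terms of reliably classifying objects, the 3D detector can naturally match the performance of a 2D detector. The remaining challenge is thus reduced to 3D localization in the camera coordinate system.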

The authors propose three key designs to improve 3D estimation.

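1. Construct 3D anchors that operate in image space, and initialize every anchor with prior statistics for each of its 3D parameters. Because the camera viewpoint is fixed and consistent and 2D scale correlates with 3D depth, each discrete anchor carries a strong prior for reasoning in 3D.

2. Design a new depth-aware convolutional layer (shown in blue in Figure 1) that can learn spatially-aware features. Convolution is traditionally preferred to be spatially invariant [18, 19] so that objects can be detected at arbitrary image locations. While this is appropriate for low-level features, high-level features stand to benefit from greater awareness of depth, provided the camera scene geometry is reasonably consistent.

3. In a post-optimization algorithm, use a 3D-to-2D projection consistency loss to refine the orientation estimate \theta. The 2D bounding box helps correct anomalous \theta estimates.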
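The depth-aware convolution in item 2 can be illustrated with a minimal sketch. It assumes, following the M3D-RPN paper, that the feature map is divided into b horizontal row bins and that each bin learns its own kernel. The function name depth_aware_conv1x1, the 1x1 kernel, and the NumPy implementation are simplifications introduced here for illustration; they are not the repository's code.

```python
import numpy as np

def depth_aware_conv1x1(feat, kernels):
    """Apply a separate 1x1 convolution to each horizontal band of rows.

    feat:    array of shape (C_in, H, W)
    kernels: array of shape (b, C_out, C_in), one weight matrix per row bin
    """
    _, h, w = feat.shape
    b, c_out, _ = kernels.shape
    # Row boundaries that split the H rows into b near-equal bands.
    edges = np.linspace(0, h, b + 1).astype(int)
    out = np.empty((c_out, h, w), dtype=float)
    for i in range(b):
        r0, r1 = edges[i], edges[i + 1]
        # Mix channels independently at every pixel of this band,
        # using the kernel that belongs to this band only.
        out[:, r0:r1, :] = np.einsum('oc,chw->ohw', kernels[i], feat[:, r0:r1, :])
    return out

# Hypothetical example: two bins, the lower band is scaled twice as strongly.
feat = np.ones((2, 4, 3))
kernels = np.stack([np.eye(2), 2.0 * np.eye(2)])
out = depth_aware_conv1x1(feat, kernels)
```

Because rows of a driving image roughly correlate with distance from the camera, giving each band its own weights lets the high-level features specialize by depth, at the cost of b times as many kernel parameters for that layer.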

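In summary:

  • Using a shared 2D and 3D detection space, build a standalone monocular 3D region proposal network (M3D-RPN), with prior statistics providing a strong initialization for each 3D parameter.
  • Propose depth-aware convolution to improve 3D parameter estimation, enabling the network to learn more spatially-aware high-level features.
  • Propose a simple post-optimization algorithm for orientation estimation that uses 3D projections and 2D detections to improve the \theta estimate.
  • Achieve state-of-the-art results on the urban KITTI benchmark for monocular bird's-eye-view and 3D detection using a single multi-class network.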

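Challenges:

  • Depth information is missing, so predicting 3D position from a 2D image is difficult
  • Camera sensors are sensitive and strongly affected by the environment (night, rain, etc.)
  • At the image level, occlusion and truncation seriously degrade perception accuracy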

Paper: M3D-RPN: Monocular 3D Region Proposal Network for Object Detection

GitHub: https://github.com/garrickbrazil/M3D-RPN

Official project page: http://cvlab.cse.msu.edu/project-m3d-rpn.html

A good translation of the paper: [Paper translation] M3D-RPN: Monocular 3D Region Proposal Network for Object Detection

2. Installation

2.1 Requirements

  • Python >= 3.6
  • paddlepaddle >= 2.0.2
  • cuda >= 9
  • boost library
  • common Python libraries

In [2]

# Install dependencies
! pip install shapely
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting shapely
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d1/ec/3038263d69a0065d3ab6944ae839f5f00896efd29b13ae62d73c00345b95/Shapely-1.8.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 9.8 MB/s eta 0:00:00:00:0100:01
Installing collected packages: shapely
Successfully installed shapely-1.8.2

[notice] A new release of pip available: 22.1.2 -> 22.2.1
[notice] To update, run: pip install --upgrade pip

2.2 Extracting the data and code

The project code is in M3D-RPN-2.0.zip and the dataset is in kitti.tar; extract both to a suitable path to use them.

In [ ]

# To extract somewhere else, change the paths below
# (default location: /home/aistudio)
! tar xf ~/data/data141443/kitti.tar
! unzip -qo ~/data/data141443/M3D-RPN-2.0.zip
! rm -rf __MACOSX

3. Data Preparation

In [3]

# Remove existing symlinks
! rm -rf ~/M3D-RPN-2.0/dataset/kitti_split1/training
! rm -rf ~/M3D-RPN-2.0/dataset/kitti_split1/validation

In [4]

%cd ~/M3D-RPN-2.0/
! python dataset/kitti_split1/setup_split.py
! sh dataset/kitti_split1/devkit/cpp/build.sh
! cd lib/nms && make
/home/aistudio/M3D-RPN-2.0
Linking train
Linking val
Done
evaluate_object.cpp: In function ‘void saveAndPlotPlots(std::__cxx11::string, std::__cxx11::string, std::__cxx11::string, std::vector<double>*, bool)’:
evaluate_object.cpp:763:11: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
     system(command);
     ~~~~~~^~~~~~~~~
evaluate_object.cpp:768:9: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
   system(command);
   ~~~~~~^~~~~~~~~
evaluate_object.cpp:770:9: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
   system(command);
   ~~~~~~^~~~~~~~~
evaluate_object.cpp:772:9: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
   system(command);
   ~~~~~~^~~~~~~~~
evaluate_object.cpp: In function ‘bool eval(std::__cxx11::string, Mail*)’:
evaluate_object.cpp:786:9: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
   system(("mkdir " + plot_dir).c_str());
   ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
evaluate_object.cpp: In function ‘int32_t main(int32_t, char**)’:
evaluate_object.cpp:925:11: warning: ignoring return value of ‘int system(const char*)’, declared with attribute warn_unused_result [-Wunused-result]
     system(("rm -r results/" + result_sha).c_str());
     ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
python setup.py build_ext --inplace
running build_ext
skipping 'cpu_nms.c' Cython extension (up-to-date)
skipping 'gpu_nms.cpp' Cython extension (up-to-date)
rm -rf build

The dataset consists of images, labels, and camera parameters (calib). Each label file contains the following fields:

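type---object class

truncated---degree of truncation (0 to 1)

occluded---occlusion state (0 to 3)

alpha---observation angle

bbox---2D bounding box of the obstacle

dimension---3D size of the obstacle

location---3D position of the center of the obstacle's bottom face

rotation_y---heading (yaw) angle of the obstacle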
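As a minimal sketch of how these fields map onto one line of a KITTI label file, the parser below splits a whitespace-separated line into the fields listed above. The class name KittiObject, the function name parse_kitti_label_line, and the sample values are illustrative assumptions, not part of the project code. The field order and the height, width, length ordering of the dimensions follow the KITTI object devkit.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class KittiObject:
    type: str
    truncated: float         # 0 (fully inside the image) to 1 (truncated)
    occluded: int            # 0 to 3
    alpha: float             # observation angle, radians
    bbox: List[float]        # [x1, y1, x2, y2] in pixels
    dimensions: List[float]  # [height, width, length] in meters
    location: List[float]    # [x, y, z] in camera coordinates, meters
    rotation_y: float        # yaw around the camera Y axis, radians

def parse_kitti_label_line(line: str) -> KittiObject:
    """Parse one whitespace-separated object line of a label_2/*.txt file."""
    f = line.split()
    return KittiObject(
        type=f[0],
        truncated=float(f[1]),
        occluded=int(float(f[2])),
        alpha=float(f[3]),
        bbox=[float(v) for v in f[4:8]],
        dimensions=[float(v) for v in f[8:11]],
        location=[float(v) for v in f[11:14]],
        rotation_y=float(f[14]),
    )

# Made-up sample line, for illustration only.
sample = "Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59"
obj = parse_kitti_label_line(sample)
```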

The final dataset directory layout is:

kitti
└── training
    ├── calib
    ├── image_2
    └── label_2

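4. Model Architecture

Monocular 3D detection offers two options: anchor-based and anchor-free approaches.

Anchor-based: 3D boxes are estimated directly from the image without predicting an intermediate 3D scene representation; a region proposal network generates 3D anchors for the given image. Unlike earlier class-agnostic 2D anchors, the shape of a 3D anchor is usually strongly correlated with its semantic label.

Anchor-free: the 2D detector CenterNet is extended to an image-based 3D detector. The framework encodes an object as a single point (the object's center) and uses keypoint estimation to find it. Several parallel heads then estimate the object's other properties, including depth, dimensions, location, and orientation.

Anchor-based methods use the mean statistics of 3D obstacles as prior knowledge and give better 3D detection results in real-world deployment, so we adopt the classic anchor-based approach. For the backbone we choose DenseNet, which densely connects every layer to all subsequent layers to enable feature reuse, with the benefits of fewer parameters and resistance to overfitting.

Given the real-time requirements of monocular 3D detection, we choose DenseNet121 as the backbone.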

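Data filtering: each bbox label is filtered by its visibility and size, and each image is filtered by whether any bbox remains, which balances foreground and background overall and keeps training stable.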
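The filtering rule can be sketched as follows. The thresholds and label list are copied from the conf.pkl dump printed in section 6 (lbls, min_gt_h, max_gt_h, min_gt_vis); the function names, the dictionary layout of a box, and the assumption that a visibility score in [0, 1] is already available for each box are hypothetical and only serve the illustration.

```python
# Values copied from the conf.pkl dump in section 6.
LBLS = ('Car', 'Pedestrian', 'Cyclist')
MIN_GT_H, MAX_GT_H, MIN_GT_VIS = 32.0, 384.0, 0.65

def keep_box(cls, bbox, visibility):
    """Return True if a ground-truth box should count as foreground.

    bbox is [x1, y1, x2, y2] in pixels; visibility is a score in [0, 1]
    (how it is derived from occlusion and truncation is not shown here).
    """
    height = bbox[3] - bbox[1]
    return (cls in LBLS
            and MIN_GT_H <= height <= MAX_GT_H
            and visibility >= MIN_GT_VIS)

def filter_image(boxes):
    """Keep the usable boxes of one image; an empty list marks an empty image."""
    return [b for b in boxes if keep_box(b['cls'], b['bbox'], b['visibility'])]

# Hypothetical boxes for one image.
boxes = [
    {'cls': 'Car', 'bbox': [100, 150, 180, 200], 'visibility': 1.0},        # kept
    {'cls': 'Car', 'bbox': [300, 170, 320, 190], 'visibility': 1.0},        # too small
    {'cls': 'Van', 'bbox': [400, 100, 500, 200], 'visibility': 1.0},        # class not in LBLS
    {'cls': 'Pedestrian', 'bbox': [50, 100, 80, 190], 'visibility': 0.3},   # not visible enough
]
kept = filter_image(boxes)
```

In the training log of section 5, the line "Found 3534 foreground and 178 empty images" is consistent with this kind of per-image filtering.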

Data augmentation: two strategies are mainly used, RandomFlip and Resize.

Anchor definition: [figure: model output, 2D anchor definition, 3D anchor definition]

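Post-processing optimization: the predicted 3D attributes are assembled into a 3D box, which is projected onto the image to obtain eight projected corner points. The smallest axis-aligned box enclosing those eight points is compared with the predicted 2D box by IoU, and the rotation angle ry or the depth z is adjusted iteratively to maximize that IoU. The algorithm relies on the prior knowledge that 2D detections are more accurate than 3D detections, and uses the 2D box to correct the predicted 3D attributes, thereby improving 3D localization accuracy. The overall framework is shown in the figure below: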

After this adjustment, the before/after comparison on the Car class is as follows:

3D detection (Car)     Easy     Mod      Hard
Before optimization    16.57    13.82    12.30
After optimization     19.09    15.70    13.15
Gain                   +2.52    +1.88    +0.85
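The post-processing optimization described in this section can be sketched as a simple hill-climbing search over ry. This is a minimal illustration under stated assumptions, not the repository's implementation: the function names, the step size, the decay factor, and the stopping rule are hypothetical; the 3D location is taken to be the box center rather than the bottom-face center; and only ry is adjusted, while the text above also allows adjusting the depth z.

```python
import numpy as np

def project_3d_box(p2, x, y, z, w, h, l, ry):
    """Project a 3D box (center x, y, z; size w, h, l; yaw ry) to its enclosing 2D box."""
    # Eight corners in the object frame, centered at the origin.
    xc = np.array([l, l, -l, -l, l, l, -l, -l]) / 2.0
    yc = np.array([h, h, h, h, -h, -h, -h, -h]) / 2.0
    zc = np.array([w, -w, -w, w, w, -w, -w, w]) / 2.0
    # Rotation about the camera Y axis, then translation to the box center.
    c, s = np.cos(ry), np.sin(ry)
    rot = np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
    corners = rot @ np.vstack([xc, yc, zc]) + np.array([[x], [y], [z]])
    # Pinhole projection with a 3x4 camera matrix, then divide by depth.
    pts = p2 @ np.vstack([corners, np.ones((1, 8))])
    u, v = pts[0] / pts[2], pts[1] / pts[2]
    return np.array([u.min(), v.min(), u.max(), v.max()])

def iou_2d(a, b):
    """IoU of two boxes given as [x1, y1, x2, y2]."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def refine_ry(p2, box_2d, center, dims, ry,
              step=0.3, decay=0.5, min_step=0.01, max_iters=200):
    """Hill-climb ry so the projected 3D box agrees better with the 2D box."""
    x, y, z = center
    w, h, l = dims

    def score(angle):
        return iou_2d(project_3d_box(p2, x, y, z, w, h, l, angle), box_2d)

    best = score(ry)
    for _ in range(max_iters):
        if step < min_step:
            break
        top_score, top_ry = max((score(a), a) for a in (ry + step, ry - step))
        if top_score > best:
            best, ry = top_score, top_ry   # accept the better neighbour
        else:
            step *= decay                  # no improvement: search more finely
    return ry, best
```

Since a candidate is accepted only when it raises the IoU, the returned score is never lower than the score of the initial ry. The same loop could be repeated for z by perturbing the depth instead of the angle.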

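5. Model Training

Training is split into a warmup configuration and a main configuration. See the files in config for details.

Before launching training, you can edit the relevant entries in the configuration file, mainly the dataset path and the number of classes. The corresponding entries are shown below:

Basic configuration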
  solver_type: 		'sgd'
  lr:  			0.004
  momentum: 		0.9
  weight_decay: 	0.0005
  max_iter: 		50000
  snapshot_iter: 	10000
  display: 		20
  do_test: 		True
Dataset paths
  dataset_test: 'kitti_split1'
  datasets_train:
    name: 		'kitti_split1'
    anno_fmt: 		'kitti_det'
    im_ext: 		'.png'
    scale: 		1

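Other settings, such as the optimizer configuration, label information, and detector sampling, can be found in the config directory.

  • Launch training with the warmup configuration (without depth-aware layers)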

In [1]

! python train.py --conf=kitti_3d_multi_warmup
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
2022-07-30 15:44:24,451-INFO: Loading KITTI dataset from dataset ...
2022-07-30 15:44:24,451-INFO: Loading imgs_label kitti_split1
2022-07-30 15:44:35,085-INFO: 1000/3712, dt: 0.0106, eta: 28.8s
2022-07-30 15:44:48,086-INFO: 2000/3712, dt: 0.0118, eta: 20.2s
2022-07-30 15:45:01,280-INFO: 3000/3712, dt: 0.0123, eta: 8.7s
2022-07-30 15:45:11,044-INFO: weighted respectively as 1.00 and 0.00
2022-07-30 15:45:11,045-INFO: Found 3534 foreground and 178 empty images
W0730 15:45:11.055980  1527 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0730 15:45:11.060575  1527 device_context.cc:465] device: 0, cuDNN Version: 7.6.
load pretrain model from  pretrained_model/densenet.pdparams
2022-07-30 15:45:27,696-INFO: iter: 20, acc (bg: 0.68, fg: 0.28, iou: 0.57), loss (bbox_3d: 2.9470, cls: 1.7128, iou: 0.5649), misc (ry: 1.67, z: 3.07), dt: 3.16, eta: 43.9h
epoch 0 | batch step 20 | iter 20, batch cost: 0.49549, loss 4.922
2022-07-30 15:45:40,023-INFO: iter: 40, acc (bg: 1.00, fg: 0.17, iou: 0.58), loss (bbox_3d: 2.7112, cls: 0.9902, iou: 0.5598), misc (ry: 1.53, z: 2.70), dt: 1.89, eta: 26.2h
epoch 0 | batch step 40 | iter 40, batch cost: 0.42486, loss 5.402
2022-07-30 15:45:52,233-INFO: iter: 60, acc (bg: 1.00, fg: 0.10, iou: 0.58), loss (bbox_3d: 2.8668, cls: 0.8062, iou: 0.5496), misc (ry: 1.70, z: 2.58), dt: 1.46, eta: 20.3h
epoch 0 | batch step 60 | iter 60, batch cost: 0.53743, loss 7.218
2022-07-30 15:46:03,758-INFO: iter: 80, acc (bg: 1.00, fg: 0.18, iou: 0.59), loss (bbox_3d: 2.6783, cls: 0.6714, iou: 0.5409), misc (ry: 1.73, z: 2.85), dt: 1.24, eta: 17.2h
epoch 0 | batch step 80 | iter 80, batch cost: 0.52226, loss 3.733
2022-07-30 15:46:16,135-INFO: iter: 100, acc (bg: 1.00, fg: 0.35, iou: 0.60), loss (bbox_3d: 2.4326, cls: 0.5723, iou: 0.5181), misc (ry: 1.44, z: 2.96), dt: 1.12, eta: 15.5h
epoch 0 | batch step 100 | iter 100, batch cost: 0.47307, loss 2.757
2022-07-30 15:46:28,455-INFO: iter: 120, acc (bg: 1.00, fg: 0.42, iou: 0.60), loss (bbox_3d: 2.0926, cls: 0.5493, iou: 0.5161), misc (ry: 1.41, z: 2.62), dt: 1.03, eta: 14.3h
epoch 0 | batch step 120 | iter 120, batch cost: 0.47274, loss 3.074
2022-07-30 15:46:40,738-INFO: iter: 140, acc (bg: 0.99, fg: 0.48, iou: 0.61), loss (bbox_3d: 2.1851, cls: 0.4690, iou: 0.5022), misc (ry: 1.66, z: 2.50), dt: 0.97, eta: 13.5h
epoch 0 | batch step 140 | iter 140, batch cost: 0.48107, loss 2.733
2022-07-30 15:46:52,978-INFO: iter: 160, acc (bg: 0.99, fg: 0.59, iou: 0.62), loss (bbox_3d: 2.1454, cls: 0.5109, iou: 0.4816), misc (ry: 1.54, z: 2.75), dt: 0.93, eta: 12.9h
epoch 0 | batch step 160 | iter 160, batch cost: 0.45592, loss 3.518
2022-07-30 15:47:05,190-INFO: iter: 180, acc (bg: 1.00, fg: 0.38, iou: 0.62), loss (bbox_3d: 2.2131, cls: 0.5824, iou: 0.4913), misc (ry: 1.30, z: 2.69), dt: 0.89, eta: 12.4h
epoch 0 | batch step 180 | iter 180, batch cost: 0.45615, loss 2.335
2022-07-30 15:47:17,361-INFO: iter: 200, acc (bg: 1.00, fg: 0.44, iou: 0.61), loss (bbox_3d: 2.2246, cls: 0.6390, iou: 0.5003), misc (ry: 1.39, z: 2.87), dt: 0.86, eta: 12.0h
epoch 0 | batch step 200 | iter 200, batch cost: 0.47935, loss 4.167
2022-07-30 15:47:29,857-INFO: iter: 220, acc (bg: 0.99, fg: 0.44, iou: 0.62), loss (bbox_3d: 2.1621, cls: 0.5316, iou: 0.4924), misc (ry: 1.69, z: 2.62), dt: 0.84, eta: 11.7h
epoch 0 | batch step 220 | iter 220, batch cost: 0.50687, loss 2.841
2022-07-30 15:47:42,019-INFO: iter: 240, acc (bg: 1.00, fg: 0.65, iou: 0.64), loss (bbox_3d: 1.8156, cls: 0.4068, iou: 0.4596), misc (ry: 1.28, z: 2.81), dt: 0.82, eta: 11.4h
epoch 0 | batch step 240 | iter 240, batch cost: 0.50188, loss 1.798
2022-07-30 15:47:54,038-INFO: iter: 260, acc (bg: 0.99, fg: 0.50, iou: 0.62), loss (bbox_3d: 2.0267, cls: 0.5757, iou: 0.4855), misc (ry: 1.51, z: 2.03), dt: 0.81, eta: 11.1h
epoch 0 | batch step 260 | iter 260, batch cost: 0.54614, loss 2.679
2022-07-30 15:48:05,952-INFO: iter: 280, acc (bg: 0.99, fg: 0.72, iou: 0.65), loss (bbox_3d: 1.6289, cls: 0.3576, iou: 0.4421), misc (ry: 1.36, z: 1.91), dt: 0.79, eta: 10.9h
epoch 0 | batch step 280 | iter 280, batch cost: 0.48394, loss 3.034
2022-07-30 15:48:16,912-INFO: iter: 300, acc (bg: 0.90, fg: 0.36, iou: nan), loss (bbox_3d: 2.6962, cls: 1.0629, iou: nan), misc (ry: 1.59, z: 2.99), dt: 0.77, eta: 10.7h
epoch 0 | batch step 300 | iter 300, batch cost: 0.32453, loss nan
2022-07-30 15:48:26,841-INFO: iter: 320, acc (bg: 0.81, fg: 0.15, iou: nan), loss (bbox_3d: 2.7870, cls: 1.7044, iou: nan), misc (ry: 1.52, z: 3.33), dt: 0.76, eta: 10.5h
epoch 0 | batch step 320 | iter 320, batch cost: 0.43582, loss 5.697
2022-07-30 15:48:36,756-INFO: iter: 340, acc (bg: 0.94, fg: 0.07, iou: nan), loss (bbox_3d: 2.8472, cls: 1.6913, iou: nan), misc (ry: 1.36, z: 3.00), dt: 0.74, eta: 10.2h
epoch 0 | batch step 340 | iter 340, batch cost: 0.34816, loss nan
2022-07-30 15:48:46,644-INFO: iter: 360, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.7810, cls: 1.6852, iou: nan), misc (ry: 1.54, z: 3.50), dt: 0.73, eta: 10.0h
epoch 0 | batch step 360 | iter 360, batch cost: 0.31263, loss nan
2022-07-30 15:48:56,450-INFO: iter: 380, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.6894, cls: 1.6709, iou: nan), misc (ry: 1.48, z: 2.73), dt: 0.72, eta: 9.9h
epoch 0 | batch step 380 | iter 380, batch cost: 0.35810, loss nan
2022-07-30 15:49:06,365-INFO: iter: 400, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8477, cls: 1.6712, iou: nan), misc (ry: 1.60, z: 3.09), dt: 0.70, eta: 9.7h
epoch 0 | batch step 400 | iter 400, batch cost: 0.33034, loss nan
2022-07-30 15:49:15,709-INFO: iter: 420, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8604, cls: 1.6656, iou: nan), misc (ry: 1.74, z: 3.08), dt: 0.69, eta: 9.6h
epoch 0 | batch step 420 | iter 420, batch cost: 0.35772, loss nan
2022-07-30 15:49:25,107-INFO: iter: 440, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.1004, cls: 1.6635, iou: nan), misc (ry: 1.73, z: 3.00), dt: 0.68, eta: 9.4h
epoch 0 | batch step 440 | iter 440, batch cost: 0.44086, loss 4.242
2022-07-30 15:49:34,778-INFO: iter: 460, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9430, cls: 1.6531, iou: nan), misc (ry: 1.55, z: 3.01), dt: 0.67, eta: 9.3h
epoch 0 | batch step 460 | iter 460, batch cost: 0.39412, loss nan
2022-07-30 15:49:44,426-INFO: iter: 480, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.6437, cls: 1.6503, iou: nan), misc (ry: 1.71, z: 2.80), dt: 0.67, eta: 9.2h
epoch 0 | batch step 480 | iter 480, batch cost: 0.42811, loss 5.613
2022-07-30 15:49:54,486-INFO: iter: 500, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8225, cls: 1.6487, iou: nan), misc (ry: 1.56, z: 4.15), dt: 0.66, eta: 9.1h
epoch 0 | batch step 500 | iter 500, batch cost: 0.41976, loss 4.490
2022-07-30 15:50:05,250-INFO: iter: 520, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.7804, cls: 1.6359, iou: nan), misc (ry: 1.52, z: 2.93), dt: 0.66, eta: 9.0h
epoch 0 | batch step 520 | iter 520, batch cost: 0.48338, loss 4.897
2022-07-30 15:50:16,740-INFO: iter: 540, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.6938, cls: 1.6308, iou: nan), misc (ry: 1.55, z: 3.69), dt: 0.65, eta: 9.0h
epoch 0 | batch step 540 | iter 540, batch cost: 0.40569, loss 5.528
2022-07-30 15:50:27,389-INFO: iter: 560, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.0550, cls: 1.6286, iou: nan), misc (ry: 1.64, z: 3.63), dt: 0.65, eta: 8.9h
epoch 0 | batch step 560 | iter 560, batch cost: 0.39938, loss nan
2022-07-30 15:50:37,931-INFO: iter: 580, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.7601, cls: 1.6200, iou: nan), misc (ry: 1.47, z: 2.83), dt: 0.64, eta: 8.8h
epoch 0 | batch step 580 | iter 580, batch cost: 0.33427, loss nan
2022-07-30 15:50:48,232-INFO: iter: 600, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8752, cls: 1.6214, iou: nan), misc (ry: 1.70, z: 3.11), dt: 0.64, eta: 8.8h
epoch 0 | batch step 600 | iter 600, batch cost: 0.42065, loss nan
2022-07-30 15:50:57,694-INFO: iter: 620, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.4335, cls: 1.6145, iou: nan), misc (ry: 1.46, z: 2.35), dt: 0.63, eta: 8.7h
epoch 0 | batch step 620 | iter 620, batch cost: 0.41829, loss 4.738
2022-07-30 15:51:08,082-INFO: iter: 640, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9969, cls: 1.6115, iou: nan), misc (ry: 1.57, z: 3.44), dt: 0.63, eta: 8.6h
epoch 0 | batch step 640 | iter 640, batch cost: 0.44508, loss 4.876
2022-07-30 15:51:17,536-INFO: iter: 660, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8563, cls: 1.6088, iou: nan), misc (ry: 1.54, z: 3.32), dt: 0.63, eta: 8.6h
epoch 0 | batch step 660 | iter 660, batch cost: 0.31495, loss nan
2022-07-30 15:51:27,147-INFO: iter: 680, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9002, cls: 1.6037, iou: nan), misc (ry: 1.69, z: 2.83), dt: 0.62, eta: 8.5h
epoch 0 | batch step 680 | iter 680, batch cost: 0.38046, loss nan
2022-07-30 15:51:36,771-INFO: iter: 700, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.7077, cls: 1.6013, iou: nan), misc (ry: 1.53, z: 3.16), dt: 0.62, eta: 8.5h
epoch 0 | batch step 700 | iter 700, batch cost: 0.44091, loss 4.932
2022-07-30 15:51:47,549-INFO: iter: 720, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.1246, cls: 1.5999, iou: nan), misc (ry: 1.63, z: 3.52), dt: 0.62, eta: 8.4h
epoch 0 | batch step 720 | iter 720, batch cost: 0.36298, loss nan
2022-07-30 15:51:57,578-INFO: iter: 740, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8187, cls: 1.5975, iou: nan), misc (ry: 1.36, z: 3.08), dt: 0.61, eta: 8.4h
epoch 0 | batch step 740 | iter 740, batch cost: 0.35585, loss nan
2022-07-30 15:52:07,169-INFO: iter: 760, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8386, cls: 1.5943, iou: nan), misc (ry: 1.48, z: 3.21), dt: 0.61, eta: 8.3h
epoch 0 | batch step 760 | iter 760, batch cost: 0.45888, loss 4.159
2022-07-30 15:52:17,427-INFO: iter: 780, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.5863, cls: 1.5915, iou: nan), misc (ry: 1.53, z: 2.91), dt: 0.61, eta: 8.3h
epoch 0 | batch step 780 | iter 780, batch cost: 0.42366, loss 5.303
2022-07-30 15:52:26,824-INFO: iter: 800, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.0004, cls: 1.5904, iou: nan), misc (ry: 1.48, z: 3.07), dt: 0.60, eta: 8.2h
epoch 0 | batch step 800 | iter 800, batch cost: 0.33793, loss nan
2022-07-30 15:52:37,415-INFO: iter: 820, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.5873, cls: 1.5835, iou: nan), misc (ry: 1.62, z: 2.84), dt: 0.60, eta: 8.2h
epoch 0 | batch step 820 | iter 820, batch cost: 0.33032, loss nan
2022-07-30 15:52:47,838-INFO: iter: 840, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.0223, cls: 1.5823, iou: nan), misc (ry: 1.65, z: 3.18), dt: 0.60, eta: 8.2h
epoch 0 | batch step 840 | iter 840, batch cost: 0.45120, loss 6.283
2022-07-30 15:52:58,364-INFO: iter: 860, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9777, cls: 1.5726, iou: nan), misc (ry: 1.62, z: 2.92), dt: 0.60, eta: 8.2h
epoch 0 | batch step 860 | iter 860, batch cost: 0.36559, loss nan
2022-07-30 15:53:07,671-INFO: iter: 880, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9200, cls: 1.5723, iou: nan), misc (ry: 1.72, z: 2.72), dt: 0.59, eta: 8.1h
epoch 0 | batch step 880 | iter 880, batch cost: 0.34666, loss nan
2022-07-30 15:53:18,026-INFO: iter: 900, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 3.0423, cls: 1.5669, iou: nan), misc (ry: 1.59, z: 2.94), dt: 0.59, eta: 8.1h
epoch 0 | batch step 900 | iter 900, batch cost: 0.29486, loss nan
2022-07-30 15:53:27,427-INFO: iter: 920, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.9947, cls: 1.5652, iou: nan), misc (ry: 1.70, z: 3.53), dt: 0.59, eta: 8.0h
epoch 0 | batch step 920 | iter 920, batch cost: 0.40898, loss 4.630
2022-07-30 15:53:36,758-INFO: iter: 940, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.6468, cls: 1.5581, iou: nan), misc (ry: 1.57, z: 3.04), dt: 0.59, eta: 8.0h
epoch 0 | batch step 940 | iter 940, batch cost: 0.40008, loss nan
2022-07-30 15:53:46,767-INFO: iter: 960, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8736, cls: 1.5566, iou: nan), misc (ry: 1.70, z: 2.76), dt: 0.59, eta: 8.0h
epoch 0 | batch step 960 | iter 960, batch cost: 0.38730, loss nan
2022-07-30 15:53:56,973-INFO: iter: 980, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.7599, cls: 1.5565, iou: nan), misc (ry: 1.67, z: 3.30), dt: 0.58, eta: 8.0h
epoch 0 | batch step 980 | iter 980, batch cost: 0.46592, loss nan
2022-07-30 15:54:06,729-INFO: iter: 1000, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8967, cls: 1.5525, iou: nan), misc (ry: 1.56, z: 3.01), dt: 0.58, eta: 7.9h
epoch 0 | batch step 1000 | iter 1000, batch cost: 0.31681, loss nan
2022-07-30 15:54:16,163-INFO: iter: 1020, acc (bg: 1.00, fg: 0.00, iou: nan), loss (bbox_3d: 2.8384, cls: 1.5481, iou: nan), misc (ry: 1.46, z: 3.05), dt: 0.58, eta: 7.9h

6. Model Evaluation

Default evaluation configuration: output/kitti_3d_multi_warmup/conf.pkl

pkl configuration
  model: 		"model_3d_dilate"
  solver_type: 		"sgd"
  lr: 			0.004
  momentum: 		0.9
  max_iter: 		50000
  snapshot_iter: 	10000
  do_test: 		"True"
  test_scale: 		512
  crop_size: 		[512, 1760]
  mirror_prob: 		0.5
  distort_prob: 	-1
  dataset_test: 	"kitti_split1"
  datasets_train:
    name: 			"kitti_split1"
    anno_fmt: 			"kitti_det"
    im_ext: 			".png"
    scale: 			1

How to inspect the pkl configuration file

In [ ]

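import pickle

import numpy as np  # NumPy must be installed: the config stores NumPy arrays

PKL_PATH = 'output/kitti_3d_multi_warmup/conf.pkl'

# Load the pickled training configuration and print it
with open(PKL_PATH, 'rb') as f:
    data = pickle.load(f)
print(data)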
{'model': 'model_3d_dilate', 'solver_type': 'sgd', 'lr': 0.004, 'momentum': 0.9, 'weight_decay': 0.0005, 'max_iter': 50000, 'snapshot_iter': 10000, 'display': 20, 'do_test': True, 'lr_policy': 'poly', 'lr_steps': None, 'lr_target': 4e-08, 'rng_seed': 2, 'cuda_seed': 2, 'image_means': [0.485, 0.456, 0.406], 'image_stds': [0.229, 0.224, 0.225], 'feat_stride': 16, 'has_3d': True, 'test_scale': 512, 'crop_size': [512, 1760], 'mirror_prob': 0.5, 'distort_prob': -1, 'dataset_test': 'kitti_split1', 'datasets_train': [{'name': 'kitti_split1', 'anno_fmt': 'kitti_det', 'im_ext': '.png', 'scale': 1}], 'use_3d_for_2d': True, 'percent_anc_h': [0.0625, 0.75], 'min_gt_h': 32.0, 'max_gt_h': 384.0, 'min_gt_vis': 0.65, 'ilbls': ['Van', 'ignore'], 'lbls': ['Car', 'Pedestrian', 'Cyclist'], 'batch_size': 2, 'fg_image_ratio': 1.0, 'box_samples': 0.2, 'fg_fraction': 0.2, 'bg_thresh_lo': 0, 'bg_thresh_hi': 0.5, 'fg_thresh': 0.5, 'ign_thresh': 0.5, 'best_thresh': 0.35, 'nms_topN_pre': 3000, 'nms_topN_post': 40, 'nms_thres': 0.4, 'clip_boxes': False, 'test_protocol': 'kitti', 'test_db': 'kitti', 'test_min_h': 0, 'min_det_scales': [0, 0], 'cluster_anchors': 0, 'even_anchors': 0, 'expand_anchors': 0, 'anchors': array([[-5.00000000e-01, -8.50000000e+00,  1.55000000e+01,
         2.35000000e+01,  5.19685035e+01,  5.30606061e-01,
         1.71303030e+00,  1.02484848e+00, -7.99062281e-01],
       [-8.50000000e+00, -8.50000000e+00,  2.35000000e+01,
         2.35000000e+01,  5.21756230e+01,  1.61817308e+00,
         1.60048077e+00,  3.81134615e+00, -4.52645995e-01],
       [-1.65000000e+01, -8.50000000e+00,  3.15000000e+01,
         2.35000000e+01,  4.83343975e+01,  1.64394636e+00,
         1.52877395e+00,  3.96643678e+00,  6.72665133e-01],
       [-2.52760863e+00, -1.25552168e+01,  1.75276089e+01,
         2.75552177e+01,  4.47814414e+01,  5.33859649e-01,
         1.77070175e+00,  9.71052632e-01,  9.29950244e-02],
       [-1.25552168e+01, -1.25552168e+01,  2.75552177e+01,
         2.75552177e+01,  4.47042842e+01,  1.59936232e+00,
         1.56904348e+00,  3.81408696e+00, -1.86627213e-01],
       [-2.25828266e+01, -1.25552168e+01,  3.75828247e+01,
         2.75552177e+01,  4.34917372e+01,  1.62085791e+00,
         1.53588472e+00,  3.91044236e+00,  7.18576848e-01],
       [-5.06911659e+00, -1.76382332e+01,  2.00691166e+01,
         3.26382332e+01,  3.46664975e+01,  5.60784314e-01,
         1.75183007e+00,  9.67189542e-01, -3.83682971e-01],
       [-1.76382332e+01, -1.76382332e+01,  3.26382332e+01,
         3.26382332e+01,  3.53501396e+01,  1.56719697e+00,
         1.59113636e+00,  3.80957071e+00, -5.10569112e-01],
       [-3.02073498e+01, -1.76382332e+01,  4.52073517e+01,
         3.26382332e+01,  3.71276897e+01,  1.60199566e+00,
         1.52918655e+00,  3.90407809e+00,  4.52309814e-01],
       [-8.25477314e+00, -2.40095463e+01,  2.32547722e+01,
         3.90095444e+01,  2.87706550e+01,  6.13238434e-01,
         1.76021352e+00,  9.79964413e-01,  6.65861306e-02],
       [-2.40095463e+01, -2.40095463e+01,  3.90095444e+01,
         3.90095444e+01,  2.83305319e+01,  1.54346232e+00,
         1.59211813e+00,  3.65965377e+00, -8.10724630e-01],
       [-3.97643166e+01, -2.40095463e+01,  5.47643166e+01,
         3.90095444e+01,  3.05408887e+01,  1.62575851e+00,
         1.52380805e+00,  3.90750258e+00,  3.11866703e-01],
       [-1.22478371e+01, -3.19956741e+01,  2.72478371e+01,
         4.69956741e+01,  2.30111798e+01,  6.06449275e-01,
         1.75815217e+00,  9.95905797e-01,  2.07553162e-01],
       [-3.19956741e+01, -3.19956741e+01,  4.69956741e+01,
         4.69956741e+01,  2.29484396e+01,  1.51036697e+00,
         1.59892202e+00,  3.41869266e+00, -1.07553604e+00],
       [-5.17435112e+01, -3.19956741e+01,  6.67435074e+01,
         4.69956741e+01,  2.50000350e+01,  1.62773519e+00,
         1.52728223e+00,  3.91656214e+00,  3.33727076e-01],
       [-1.72529469e+01, -4.20058937e+01,  3.22529488e+01,
         5.70058937e+01,  1.84786287e+01,  6.01377953e-01,
         1.74685039e+00,  1.00673228e+00,  3.47205080e-01],
       [-4.20058937e+01, -4.20058937e+01,  5.70058937e+01,
         5.70058937e+01,  1.88152861e+01,  1.48707424e+00,
         1.59864629e+00,  3.33720524e+00, -8.61836265e-01],
       [-6.67588425e+01, -4.20058937e+01,  8.17588425e+01,
         5.70058937e+01,  2.05755766e+01,  1.62259012e+00,
         1.53216288e+00,  3.94210948e+00,  3.22827453e-01],
       [-2.35266075e+01, -5.45532150e+01,  3.85266075e+01,
         6.95532150e+01,  1.50346916e+01,  6.24688797e-01,
         1.74406639e+00,  9.16597510e-01,  4.09687668e-01],
       [-5.45532150e+01, -5.45532150e+01,  6.95532150e+01,
         6.95532150e+01,  1.53459774e+01,  1.28964912e+00,
         1.65929825e+00,  3.08289474e+00, -2.75262084e-01],
       [-8.55798264e+01, -5.45532150e+01,  1.00579826e+02,
         6.95532150e+01,  1.63260674e+01,  1.61281972e+00,
         1.52656394e+00,  3.93422188e+00,  2.68282700e-01],
       [-3.13903351e+01, -7.02806702e+01,  4.63903351e+01,
         8.52806702e+01,  1.22645452e+01,  6.30507614e-01,
         1.74746193e+00,  9.53908629e-01,  3.16888822e-01],
       [-7.02806702e+01, -7.02806702e+01,  8.52806702e+01,
         8.52806702e+01,  1.18782324e+01,  1.04448598e+00,
         1.67046729e+00,  2.41532710e+00, -2.10551319e-01],
       [-1.09171005e+02, -7.02806702e+01,  1.24171005e+02,
         8.52806702e+01,  1.35797114e+01,  1.62121377e+00,
         1.53905797e+00,  3.96132246e+00,  1.88529752e-01],
       [-4.12471313e+01, -8.99942627e+01,  5.62471313e+01,
         1.04994263e+02,  9.93203179e+00,  6.10459770e-01,
         1.77063218e+00,  9.33965517e-01,  4.85936169e-01],
       [-8.99942627e+01, -8.99942627e+01,  1.04994263e+02,
         1.04994263e+02,  8.94859682e+00,  8.11129032e-01,
         1.76588710e+00,  1.66185484e+00,  7.96113709e-02],
       [-1.38741394e+02, -8.99942627e+01,  1.53741394e+02,
         1.04994263e+02,  1.10426775e+01,  1.61037123e+00,
         1.53266821e+00,  3.89865429e+00,  3.95016718e-02],
       [-5.36021461e+01, -1.14704292e+02,  6.86021423e+01,
         1.29704285e+02,  8.38891072e+00,  6.04338624e-01,
         1.79275132e+00,  9.49629630e-01,  8.05747649e-01],
       [-1.14704292e+02, -1.14704292e+02,  1.29704285e+02,
         1.29704285e+02,  8.07141465e+00,  1.01018519e+00,
         1.75148148e+00,  2.18962963e+00, -7.64042511e-02],
       [-1.75806442e+02, -1.14704292e+02,  1.90806442e+02,
         1.29704285e+02,  9.18446466e+00,  1.60643243e+00,
         1.52597297e+00,  3.86878378e+00, -6.56303932e-02],
       [-6.90885468e+01, -1.45677094e+02,  8.40885468e+01,
         1.60677094e+02,  6.92342698e+00,  6.26693548e-01,
         1.79145161e+00,  9.60241935e-01,  7.83861491e-01],
       [-1.45677094e+02, -1.45677094e+02,  1.60677094e+02,
         1.60677094e+02,  6.78398804e+00,  1.38389610e+00,
         1.61545455e+00,  2.86194805e+00, -1.03479279e+00],
       [-2.22265656e+02, -1.45677094e+02,  2.37265656e+02,
         1.60677094e+02,  7.86264322e+00,  1.61697095e+00,
         1.54979253e+00,  3.94771784e+00, -7.12191598e-02],
       [-8.85000000e+01, -1.84500000e+02,  1.03500000e+02,
         1.99500000e+02,  5.18910968e+00,  6.60465116e-01,
         1.75534884e+00,  8.40930233e-01,  1.72643724e-01],
       [-1.84500000e+02, -1.84500000e+02,  1.99500000e+02,
         1.99500000e+02,  4.38755254e+00,  7.42857143e-01,
         1.72785714e+00,  1.38142857e+00,  6.42191569e-01],
       [-2.80500000e+02, -1.84500000e+02,  2.95500000e+02,
         1.99500000e+02,  5.58339956e+00,  1.58328358e+00,
         1.54699360e+00,  3.86228145e+00, -7.18690361e-02]]), 'bbox_means': array([[-0.00022546,  0.00160404,  0.06383215, -0.09315256,  0.01069604,
        -0.06744095,  0.19155604,  0.05884239, -0.02122913,  0.06871941,
        -0.00352113]]), 'bbox_stds': array([[0.13962965, 0.1255247 , 0.24738377, 0.23853353, 0.16330168,
        0.13235298, 3.62072376, 0.38246312, 0.10154974, 0.50257567,
        1.85493732]]), 'anchor_scales': array([ 32.        ,  40.1104343 ,  50.27646685,  63.01909126,
        78.99134748,  99.01178916, 124.10643323, 155.56134174,
       194.98853052, 244.40858256, 306.35419975, 384.        ]), 'anchor_ratios': array([0.5, 1. , 1.5]), 'hard_negatives': True, 'focal_loss': 0, 'cls_2d_lambda': 1, 'iou_2d_lambda': 1, 'bbox_2d_lambda': 0, 'bbox_3d_lambda': 1, 'bbox_3d_proj_lambda': 0.0, 'hill_climbing': True, 'pretrained': 'pretrained_model/densenet.pdparams', 'visdom_port': 8100}
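The anchor_scales values and the first four columns of each anchors row in the dump above can be reproduced from a few other config entries. The sketch below is a reconstruction inferred from the printed numbers, not the repository's code; the function names are hypothetical. It assumes the scales are spaced geometrically between percent_anc_h * test_scale (32 and 384 pixels), that width = ratio * height, and that every anchor is centered at (feat_stride - 1) / 2. The remaining five columns of each row appear to hold the 3D prior statistics and are not reproduced here.

```python
import numpy as np

def anchor_scales(test_scale=512, percent_anc_h=(0.0625, 0.75), n=12):
    """Geometrically spaced anchor heights between the min and max height."""
    lo, hi = (p * test_scale for p in percent_anc_h)
    return lo * (hi / lo) ** (np.arange(n) / (n - 1))

def anchors_2d(scales, ratios=(0.5, 1.0, 1.5), feat_stride=16):
    """2D anchor corners [x1, y1, x2, y2], centered on the first feature cell."""
    ctr = (feat_stride - 1) / 2.0
    rows = []
    for h in scales:
        for r in ratios:
            w = h * r
            rows.append([ctr - w / 2, ctr - h / 2, ctr + w / 2, ctr + h / 2])
    return np.array(rows)

scales = anchor_scales()
anchors = anchors_2d(scales)
# First and last rows match the dump:
# [-0.5, -8.5, 15.5, 23.5] and [-280.5, -184.5, 295.5, 199.5]
```

With 12 scales and 3 aspect ratios this yields 36 anchors, the same number of rows as the printed anchors array.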

In [ ]

%cd ~/M3D-RPN-2.0
! python test.py \
  --conf_path output/kitti_3d_multi_warmup/conf.pkl \
  --weights_path output/kitti_3d_multi_warmup/weights/iter50000.0_params.pdparams
/home/aistudio/M3D-RPN-2.0
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
W0606 16:20:25.278681   659 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0606 16:20:25.283661   659 device_context.cc:465] device: 0, cuDNN Version: 7.6.
loaded model from  output/kitti_3d_multi_warmup/weights/iter50000.0_params.pdparams
start evaluation...
2022-06-06 16:23:50,091-INFO: testing 1000/3769, dt: 0.201, eta: 9.3m
2022-06-06 16:27:11,090-INFO: testing 2000/3769, dt: 0.201, eta: 5.9m
2022-06-06 16:30:31,986-INFO: testing 3000/3769, dt: 0.201, eta: 2.6m
Evaluation Finished!

Results of the pretrained model output/kitti_3d_multi_warmup/weights/iter50000.0_params.pdparams

Car

Metric          Easy     Mod      Hard
2D detection    87.27    81.74    66.60
3D BEV          24.93    18.58    16.69
3D detection    19.10    15.69    13.15

Pedestrian

Metric          Easy     Mod      Hard
2D detection    72.47    58.28    50.07
3D BEV           4.12     4.55     3.44
3D detection     3.77     3.45     3.07

Cyclist

Metric          Easy     Mod      Hard
2D detection    63.97    45.97    39.73
3D BEV          11.72    10.16    10.16
3D detection    10.56    10.07    10.07

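7. Model Inference

Inference consists of two steps: 1) export the inference model; 2) run the inference code.

Exporting the inference model

PaddlePaddle saves weights in two forms: training models, which support both forward inference and backward gradients, and inference models, which support forward inference only. The inference model is optimized for speed and GPU memory: tensors needed only during training are pruned to reduce memory usage, and speed optimizations such as layer fusion and kernel selection are applied. Run the following command to export the inference model.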

In [14]

! python export_model.py \
  --conf_path output/kitti_3d_multi_warmup/conf.pkl \
  --weights_path output/kitti_3d_multi_warmup/weights/iter50000.0_params.pdparams
 
 

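The exported inference model is placed in the inference directory, which contains three files:

  • inference.pdmodel
  • inference.pdiparams
  • inference.pdiparams.info

inference.pdmodel stores the structure of the inference model, while inference.pdiparams and inference.pdiparams.info store its parameter information.

Inference results are saved in the inference_result directory.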

In [16]

! python infer.py \
  --conf_path output/kitti_3d_multi_warmup/conf.pkl
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
W0607 08:43:10.540436 17651 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0607 08:43:10.544229 17651 device_context.cc:465] device: 0, cuDNN Version: 7.6.

In [15]

! python vis.py


Reposted from blog.csdn.net/qq_34106574/article/details/126073626