上一篇讲到了环境搭建,这一篇我将继续分析训练模型的过程
首先用的数据集是kitti格式的,刚开始我也是一脸蒙蔽,因为之前用的oc和coco等数据集格式,然后为就看了一下kitti里面的东西
一行有15个数据,然后我就去百度了一下kitti数据的参数解释,以及看了官网给的解释
具体如下
第1个字符串:代表物体类别
第2个数:代表物体是否被截断
从0(非截断)到1(截断)浮动,其中truncated指离开图像边界的对象
第3个数:代表物体是否被遮挡
整数0,1,2,3表示被遮挡的程度
0:完全可见 1:小部分遮挡 2:大部分遮挡 3:完全遮挡(unknown)
第4个数:alpha,物体的观察角度,范围:-pi~pi
第5~8这4个数:物体的2维边界框
xmin,ymin,xmax,ymax
第9~11这3个数:3维物体的尺寸
高、宽、长(单位:米)
第12~14这3个数:3维物体的位置
x,y,z(在照相机坐标系下,单位:米)
第15个数:3维物体的空间方向:rotation_y
在照相机坐标系下,物体的全局方向角(物体前进方向与相机坐标系x轴的夹角),范围:-pi~pi
然后我就寻思着是不是可以直接通过voc数据集转化过去呢?答案肯定是可以的
因为voc可以转到coco数据集,博主之前也弄过yolo4
博主从网上找了一些口罩的数据集(里面包含了label的标签),然后为通过如下的代码
import xml.etree.ElementTree as ET
import os
base_xml_dir = "./label1/"
xml_list = os.listdir(base_xml_dir)
kitti_saved_dir = "./label_2/"
def convert_annotation(file_name):
in_file = open(base_xml_dir + file_name)
tree = ET.parse(in_file)
root = tree.getroot()
with open(kitti_saved_dir + file_name[:-4] + '.txt', 'w') as f:
for obj in root.iter('object'):
cls = obj.find('name').text
xmlbox = obj.find('bndbox')
"""
第5~8这4个数:物体的2维边界框
xmin,ymin,xmax,ymax
"""
xmin, ymin, xmax, ymax = xmlbox.find('xmin').text, xmlbox.find('ymin').text, \
xmlbox.find('xmax').text, xmlbox.find('ymax').text
f.write(cls + " " + '0.0' + " " + '0' + " " + '1.0' + " " + str(xmin) + '.0' + " "
+ str(ymin) + '.0' + " " + str(xmax) + '.0' + " " + str(ymax) + '.0' + " " +
str((int(str(ymax)) - int(str(ymin)))/int(1000) )+ " " + str((int(str(xmax)) - int(str(xmin)) )/int(1000))+ " " + '0.1' + " " + '1.0' + " " + '0.0' + " " + '1.0' + " " + '0.0' + '\n')
for i in xml_list:
convert_annotation(i)
代码其实很简单,把我们voc数据集得不到的东西全部用来填入,但这里需要注意,填入的是整形还是浮点型(官网给的除了第一个是字符窗,第三个是整形其余到是浮点型)
至于voc数据集怎么制作,可以去看我的tensorflow objection api 的文章
至此数据集准备完毕,接下来我们需要看他的文件夹是怎么弄的
进入tlt-experiments,然后新建data文件夹
cd /workspace/tlt-experiments
sudo mkdir data
然后进入data文件,新建一个testing和training文件夹
文件夹名字可以自己改,后期我们可以自定义
image_2放的是图片
label_2放的是txt文件
接着我们进入jupyter(就算刚才打开的浏览器)
第三行的key需要我们自己输入,就是文章一里面提到的
然后直接跳到下图的位置
接着找到下图的文件夹,打开ssd_tfrecords_kitti_trainval.txt
kitti_config {
root_directory_path: "/workspace/tlt-experiments/data/training"
image_dir_name: "image_2" ##如果和我一样可以不用改,如果你上面自定义文件夹名字,这边需要改
label_dir_name: "label_2" ##
image_extension: ".png" #自己图片的格式
partition_mode: "random"
num_partitions: 2
val_split: 14
num_shards: 10
}
image_directory_path: "/workspace/tlt-experiments/data/training"
然后执行下方的三步,如果出错,赵找报错的原因,一般不会报错
然后执行这里的四步,这里的第三步可能下载有点慢
然后我们打开下方文件夹的!cat $SPECS_DIR/ssd_train_resnet18_kitti.txt,改最重要的东西
random_seed: 42
ssd_config {
aspect_ratios_global: "[1.0, 2.0, 0.5, 3.0, 1.0/3.0]"
scales: "[0.05, 0.1, 0.25, 0.4, 0.55, 0.7, 0.85]"
two_boxes_for_ar1: true
clip_boxes: false
loss_loc_weight: 0.8
focal_loss_alpha: 0.25
focal_loss_gamma: 2.0
variances: "[0.1, 0.1, 0.2, 0.2]"
arch: "resnet" ## 网络的类型,如果是mobilenet_v2,需要自己改
nlayers: 18
freeze_bn: false
}
training_config {
batch_size_per_gpu: 24 ## 建议改小,不然训练会报错
num_epochs: 80
learning_rate {
soft_start_annealing_schedule {
min_learning_rate: 5e-5
max_learning_rate: 2e-2
soft_start: 0.15
annealing: 0.5
}
}
regularizer {
type: L1
weight: 3e-06
}
}
eval_config {
validation_period_during_training: 10
average_precision_mode: SAMPLE
batch_size: 32 ## 建议改小,不然训练会报错
matching_iou_threshold: 0.5
}
nms_config {
confidence_threshold: 0.01
clustering_iou_threshold: 0.6
top_k: 200
}
augmentation_config {
preprocessing {
output_image_width: 1248
output_image_height: 384
output_image_channel: 3
crop_right: 1248
crop_bottom: 384
min_bbox_width: 1.0
min_bbox_height: 1.0
}
spatial_augmentation {
hflip_probability: 0.5
vflip_probability: 0.0
zoom_min: 0.7
zoom_max: 1.8
translate_max_x: 8.0
translate_max_y: 8.0
}
color_augmentation {
hue_rotation_max: 25.0
saturation_shift_max: 0.20000000298
contrast_scale_max: 0.10000000149
contrast_center: 0.5
}
}
dataset_config {
data_sources: {
tfrecords_path: "/workspace/tlt-experiments/data/tfrecords/kitti_trainval/kitti_trainval*"
image_directory_path: "/workspace/tlt-experiments/data/training"
}
image_extension: "png" #3 图片格式
## 改为自己的标签名字
target_class_mapping {
key: "have_mask"
value: "have_mask"
}
target_class_mapping {
key: "no_mask"
value: "no_mask"
}
validation_fold: 0
}
改完记得保存
然后运行下面的2步
然后你改了网络,那么下图的文件路径也要改
如果到这里都没有问题,那么下面的基本都是ok的
教程至此基本结束了,感谢支持