前言
网上关于VGG模型的文章有很多,有介绍算法本身的,也有代码实现,但是很多代码只给出了模型的结构实现,并不包含数据准备的部分,这让人很难愉快的将代码迁移自己的任务中。为此,这篇博客接下来围绕着如何使用VGG实现自己的图像分类任务,从数据准备到实验验证。代码基于Python与TensorFlow实现,模型结构采用VGG-16,并且将很少的出现算法和理论相关的东西。
数据准备
下载数据和转换代码
大多数人自己的训练数据,一般都是传统的图片形式,如.jpg,.png等等,而图像分类任务的话,这些图片的天然组织形式就是一个类别放在一个文件夹里,那么有啥大众化的数据集是这样的组织形式呢?TensorFlow的FlowersData,它下载下来是这个样子:
一共有五类,每一类中都有几百张图,我们把这些数据组织成TFrecord形式,对应的博客在这里,源码的github在这里,FlowersData数据集在这里。
有上面这三个东西之后,就可以生成TFrecord文件了。
组织图片数据
首先将FlowersData文件夹下的数据分成两个部分,训练数据和测试数据,我把原文件五个类别中都拿出大概100张图左右,数据的构成和路径如下:
生成训练TFrecord
#图片路径
cwd = 'F:\\flowersdata\\trainimages\\'
#文件路径
filepath = 'F:\\flowersdata\\tfrecord\\train\\'
classes=['daisy',
'dandelion',
'roses',
'sunflowers',
'tulips']
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
生成效果:
生成预测TFrecord
#图片路径
cwd = 'F:\\flowersdata\\testimages\\'
#文件路径
filepath = 'F:\\flowersdata\\tfrecord\\test\\'
classes=['daisy',
'dandelion',
'roses',
'sunflowers',
'tulips']
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
生成效果:
训练模型
初始权重与源码下载
VGG-16的初始权重我上传到了百度云,在这里下载;
VGG-16源码我上传到了github,在这里下载;
在源码中:
train_and_val.py文件是最终要执行的文件,它定了训练和预测的过程;
input_data.py是将上一步中生成的TFRecord文件组织成batch的过程;
VGG.py定义了VGG-16的网络结构;
tool.py是最底层,定义了一些卷积池化等操作。
训练模型
train_and_val.py文件修改:
if __name__=="__main__":
train()
#evaluate()
根据自己的路径修改:
#初始权重路径
pre_trained_weights = 'vgg16_pretrain/vgg16.npy'
#训练数据路径
train_data_dir = 'F:\\flowersdata\\tfrecord\\train\\traindata.tfrecords*'
test_data_dir =
#预测数据路径
'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#训练生成文件路径
train_log_dir = 'logs/train/'
#预测生成文件路径
val_log_dir = 'logs/val/'
根据自己的显存容量修改:
IMG_W = 224
IMG_H = 224
BATCH_SIZE = 8
训练过程每50个step打印loss;
每200个step计算一个batch中的准确率;
每1000个step保存一次权重。
预测
train_and_val.py文件修改:
if __name__=="__main__":
#train()
evaluate()
#训练过程中生成的权重
log_dir = 'logs/train/'
#预测数据集路径
test_data_dir = 'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#用于生成tf文件的图片数量
n_test = 502
打印测试样本总数;
打印正确预测的样本总数;
打印top_1。