使用slim模块,快速简洁实现VGG16,并且实现基于VGG16 fine-tune(全部fine-tune、只fine-tune fc8层之前的层)
注意:slim.learning.train() 与 slim.learning.create_train_op() 不像平常那样通过占位符 feed 数据,而是需要用 slim 模块的 tfrecords 流水线直接把数据送入网络,较难理解。因此这里只采用 slim 模块中的部分函数,不采用 slim 的训练流程。
slim.get_model_variables('vgg_16') 这个函数会自动匹配开头所有符合关键字的 变量
1) 实现 vgg16(所有权重随机初始化 / all weights initialized randomly):
def vgg16(inputs, is_training):
    """Build VGG-16 using the ready-made slim.nets implementation.

    Args:
        inputs: image batch tensor; presumably [batch, 224, 224, 3] — the
            size vgg_16's fully-connected head expects (TODO confirm).
        is_training: bool (or bool tensor) toggling dropout.

    Returns:
        The logits tensor for 1000 classes.
    """
    network = slim.nets.vgg
    # vgg_16 returns (logits, end_points); the original returned the whole
    # tuple under the name `net` — keep only the logits tensor.
    net, _ = network.vgg_16(inputs, 1000, is_training=is_training)
    return net
def vgg16(inputs, is_training):
    """Hand-built VGG-16 out of slim primitives (conv/pool/fc).

    Args:
        inputs: image batch tensor; must be [batch, 224, 224, 3] so that
            pool5 is 7x7x512 and the fc layers line up.
        is_training: bool (or bool tensor) toggling dropout.

    Returns:
        Logits tensor of shape [batch, 1000] (fc8, no activation).
    """
    # Shared defaults for every conv and fc layer in this scope.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0005)):
        net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')
        # BUG FIX: the original fed the rank-4 pool5 tensor straight into
        # slim.fully_connected, which applies the fc to the LAST axis only
        # and yields [batch, 7, 7, 4096] instead of VGG's [batch, 4096].
        # Flatten to [batch, 7*7*512] first.
        net = slim.flatten(net, scope='flatten5')
        net = slim.fully_connected(net, 4096, scope='fc6')
        net = slim.dropout(net, 0.5, scope='dropout6', is_training=is_training)
        net = slim.fully_connected(net, 4096, scope='fc7')
        net = slim.dropout(net, 0.5, scope='dropout7', is_training=is_training)
        # Final classifier: raw logits, no activation.
        net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
    return net
2)基于vgg16 fine-tune:
''' vgg16 '''
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets
import numpy as np
import random
def one_hot(labels, class_num):
    """Convert integer class labels to a one-hot matrix.

    Args:
        labels: 1-D array-like of int labels in [0, class_num).
        class_num: number of classes (width of the output).

    Returns:
        np.int32 array of shape [len(labels), class_num] with a single 1
        per row at the label's column.
    """
    labels = np.asarray(labels)
    # Broadcast compare [N, 1] against [class_num] — one vectorized pass
    # instead of the original O(N * class_num) double Python loop.
    return (labels[:, None] == np.arange(class_num)).astype(np.int32)
'''
More advanced tf functions: nets (call already-written network definitions)
'''
sess = tf.Session()
# Path to the pre-trained VGG-16 checkpoint used for fine-tuning.
fine_tune_path = r'**/vgg_16.ckpt'
# reader = tf.train.NewCheckpointReader(fine_tune_path)
# key = reader.get_variable_to_shape_map()  # inspect checkpoint variables
class_num = 21
# To fine-tune the FC layers too, inputs must be 224x224 (VGG's native
# size); if the FC layers are NOT fine-tuned, any input size works.
image = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='image')
label = tf.placeholder(tf.int32, shape=[None, class_num], name='label')
is_training = tf.placeholder(tf.bool, name='is_training')
network = nets.vgg
# vgg_16 returns two values: the logits tensor and an end_points dict.
net, end_points = network.vgg_16(image, class_num, is_training=is_training)
# init_fn = slim.assign_from_checkpoint_fn(fine_tune_path, slim.get_model_variables('vgg_16'))  # fine-tune ALL weights
# print(net.get_shape().as_list())
# print(end_points.keys())
# NOTE(review): adding a scalar constant to every logit shifts nothing —
# softmax and argmax are invariant to it; kept for behavioral parity.
softmax = slim.nn.softmax(net + tf.constant(1e-4))
pred = tf.argmax(softmax, axis=-1)
tf.add_to_collection('pred', pred)
init_fn = None
# Restore only the weights BEFORE fc8 (fc8's shape differs: 21 classes).
# These three lines may sit anywhere before init_fn(sess); they only load
# checkpoint weights into the session.
exclude = ['vgg_16/fc8']  # slim auto-matches fc8's weights AND biases by name prefix
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
init_fn = slim.assign_from_checkpoint_fn(fine_tune_path, variables_to_restore)
''' slim.losses.softmax_cross_entropy already averages over the batch,
unlike tf.nn's raw softmax cross-entropy, which must be averaged manually '''
loss = slim.losses.softmax_cross_entropy(logits=net, onehot_labels=label)  # label: [batch_size, num_class]
''' note: the slim.losses module is deprecated and will be removed in the future '''
train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)
if init_fn is not None:
    print('ok')
    # Must run AFTER init_op, so the restored weights are not re-initialized.
    init_fn(sess)
    print('successful fine-tune')
saver = tf.train.Saver(max_to_keep=1)
batch_size = 4
for step in range(1000):
    # Random stand-in data; replace with a real input pipeline.
    imgs = np.random.random([batch_size, 224, 224, 3])
    print(imgs.shape)
    labs = one_hot(np.array(random.sample(range(class_num), batch_size)), class_num)
    # PERF FIX: fetch train_op and loss in ONE sess.run — the original ran a
    # second full forward pass just to read the loss value.
    _, loss_val = sess.run([train_op, loss],
                           feed_dict={image: imgs, label: labs, is_training: True})
    print(loss_val)
    if step % 100 == 0:
        saver.save(sess, r'./model/model.ckpt', global_step=step)