想必大家都有这样一种感受,明明我自己学习了tensorflow很久了,但是就是觉得自己无法灵活运用,今天大家就跟着我的脚步来学习,Tensorflow的验证码识别学习
本文章代码借鉴这位仁兄,并且改进了其中的一些问题,欢迎大家来学习讨论!
链接:https://zhuanlan.zhihu.com/p/36979787
好的,废话不多说我们开始!!!
1、整理并生成自己的数据集
关于TensorFlow读取数据,官网给出了三种方法:
- 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据。
- 从文件读取数据:在TensorFlow图的起始,让一个输入管线从文件中读取数据。
- 预加载数据:在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yeild 使用更为简洁)。但是如果数据量较大,这样的方法就不适用了。因为太耗内存,所以这时最好使用TensorFlow提供的队列queue,也就是第二种方法:从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这里我们学习一种比较通用的,高效的读取方法,即使用TensorFlow内定标准格式——TFRecords。
我们这里采用Tfrecord方式制作数据集
1)、建立合适的路径:…/dataset/train(test)/0(1,2,3,4,5,6,7,8,9)
按照此先建立dataset下train,test两个文件夹,然后train,test文件夹下各建0-9的10个文件夹
文件夹的建立可在看完代码后自己更改
2)、利用captcha库生成验证码图片
新建create_images.py文件,执行
from captcha.image import ImageCaptcha
import os
def gen_captcha_text_and_image(path, i, a):
image_ = ImageCaptcha()
captcha_text = str(i)
print(captcha_text)
#captcha = image_.generate(captcha_text)
image_.write(a, path + captcha_text + '.png', format='png')
def train_images():
for i in range(2000):
for j in range(10):
path = "./dataset/train/"
gen_captcha_text_and_image(path + str(j) +'/', i, str(j))
i = i + 1
def test_images():
for i in range(1000):
for j in range(10):
path = "./dataset/test/"
gen_captcha_text_and_image(path + str(j) +'/', i, str(j))
i = i + 1
train_images()
test_images()
这样train目录下的目录名为0-9的10个文件夹每个文件夹下都有2000张对应的图片
test目录下的0-9的10个文件夹为1000张
3)、生成Tfrecord文件,,并且对应好图片相应的标签
新建一个image_to_tfrecord.py文件,执行
这里一定要记住,因为os.listdir()不会对目录中的文件进行排序,是随机读取目录,所以需要先去掉文件名的后缀然后进行排序,否则训练集的标签和图片信息不匹配,训练就没有任何意义
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
cwd = "./dataset/train/"
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
writer= tf.python_io.TFRecordWriter("0_to_9_train.tfrecords") #要生成的文件
for index, name in enumerate(classes):
classes_path = cwd + name +'/'
# 一下的三行是用于对目录的排序,因为os.listdir()不会对目录进行排序,所以标签无法对应上,会导致训练出现问题
rootpath = os.listdir(classes_path)
rootpath.sort(key=lambda x: int(x[:-4]))
for img_name in rootpath:
print("img_name", img_name)
img_path = classes_path + img_name #生成一个带有所有图片文件名的列表
print("img_path:", img_path)
img = Image.open(img_path)
img_raw = img.tobytes() #将图片变成二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
4)、读取Tfrecord标签
新建一个tfrecord_to_image.py文件,不需要执行,等待主文件调用
这里没有什么问题
import numpy as np
import tensorflow as tf
def read_and_decode(filename): # 读入tfrecords
filename_queue = tf.train.string_input_producer([filename], shuffle=True) # 生成一个queue队列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
}) # 将image数据和label取出来
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [120, 120, 3])
img = tf.cast(img, tf.float32) # 在流中抛出img张量
label = tf.cast(features['label'], tf.int32) # 在流中抛出label张量
return img, label
至此我们的前期准备已经完成了
2、简化版AlexNet网络结构训练
CNN部分的网络我们这里不做讲解,这篇文章更多的是讲解一个训练模型的建立流程
注意这一行代码
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=batch_size, capacity=20000, # 因为我们这里有两万张图片,每个数字有2000张所以需要选择20000,不然会导致缺失数字
min_after_dequeue=1)
- min_after_dequeue —队列中文件的最小数量
- capacity — 队列中的文件的最大数量
特别注意:我们这里的训练集有两万张图片,进入队列后要想完整的读取到每一个数字,需要将capacity设置到大于20000,不然会缺失数据集
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time
import tfrecord_to_image
# https://www.cnblogs.com/cvtoEyes/p/8981994.html
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# https://github.com/yhlleo/tensorflow.cifar10/blob/master/cifar10_input.py
batch_size = 10
def weight_varialbe(shape):
return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))
def bias_variable(shape):
return tf.constant(0.1, shape=shape)
def conv_2d(input, w):
return tf.nn.relu(tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding='VALID'))
def max_pool(input):
return tf.nn.max_pool(input, ksize=[1, 5, 5, 1], strides=[1, 2, 2, 1], padding='VALID')
# 全连接层
def full_connect(input, output_depth):
input_depth = input.get_shape().as_list()[-1]
# print(input.get_shape().as_list())
w = weight_varialbe([input_depth, output_depth])
# print(w.get_shape().as_list())
b = bias_variable([output_depth])
# print('ff')
# print(b.get_shape().as_list())
fc = tf.nn.bias_add(tf.matmul(input, w), b)
return tf.nn.relu(fc)
def full_connect_final(input, output_depth):
input_depth = input.get_shape().as_list()[-1]
fc = tf.nn.bias_add(tf.matmul(input, weight_varialbe([input_depth, output_depth])), bias_variable([output_depth]))
return fc
# 第一层卷积池化层
# 下载数据
train_imgs = tf.placeholder(tf.float32, [batch_size, 120, 120, 3])
train_labels = tf.placeholder(tf.int32, [batch_size])
# 第一层卷积
con_w1 = weight_varialbe([11, 11, 3, 64])
net = conv_2d(train_imgs, con_w1)
print("第一层卷积:", net)
# 第二层池化
net = max_pool(net)
print("第二层池化:", net)
# 第三层卷积
con_w2 = weight_varialbe([11, 11, 64, 64])
net = conv_2d(net, con_w2)
print("第三层卷积:", net)
# 第四层卷积
con_w3 = weight_varialbe([11, 11, 64, 64])
net = conv_2d(net, con_w3)
print("第四层卷积:", net)
# 第五层池化
net = max_pool(net)
print("第五层池化:", net)
net = tf.reshape(net, [-1, 15*15*64])
# 第五六层全连接层
net = full_connect(net, 384)
# 第七层全连接层
net = full_connect(net, 192)
# 第八层全连接层
net = full_connect_final(net, 10)
print("net shape:", net.shape)
train_loss = tf.losses.sparse_softmax_cross_entropy(labels=train_labels, logits=net)
lr = 0.0001
opt = tf.train.AdamOptimizer(lr)
train_op = opt.minimize(train_loss)
predict = tf.argmax(net, axis=-1,output_type=tf.int32)
train_acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(net, axis=-1,output_type=tf.int32), train_labels), tf.float32))
img, label = tfrecord_to_image.read_and_decode("0_to_9_train.tfrecords")
img_test, label_test = tfrecord_to_image.read_and_decode("0_to_9_test.tfrecords")
#使用shuffle_batch可以随机打乱输入
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=batch_size, capacity=20000, # 因为我们这里有两万张图片,每个数字有2000张所以需要选择20000,不然会导致缺失数字
min_after_dequeue=1)
print("img.shape", img_batch.shape)
print("label.shape", label_batch.shape)
# img_test, label_test = tf.train.shuffle_batch([img_test, label_test],
# batch_size=batch_size, capacity=6000,
# min_after_dequeue=1000)
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
# gragh_writer = tf.summary.FileWriter('.', sess.graph)
sess.run(init)
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(2000):
batch0, batch1 = sess.run([img_batch, label_batch])
sess.run(train_op, feed_dict={train_imgs : batch0, train_labels : batch1})
if i%10 == 0:
get_train_loss = sess.run(train_loss, feed_dict={train_imgs : batch0, train_labels : batch1})
train_accuy = sess.run(train_acc, feed_dict={train_imgs : batch0, train_labels : batch1})
print("train_loss", get_train_loss)
print("STEP %d Accuray %g" % (i, train_accuy))
print(sess.run(predict, feed_dict={train_imgs : batch0, train_labels : batch1}))
print(batch1)
fig, axes = plt.subplots(ncols=5, nrows=2)
for ind, (image, label) in enumerate(zip(batch0, batch1)):
image = image / 255.
row = ind // 5
col = ind % 5
axes[row][col].imshow(image, cmap='gray') # 灰度图
axes[row][col].axis('off')
axes[row][col].set_title('%d' % label)
plt.show()
展示:
这里为了检查读取状态是否成功我画了部分的图,如果大家只想训练的话,请删减以下代码:
fig, axes = plt.subplots(ncols=5, nrows=2)
for ind, (image, label) in enumerate(zip(batch0, batch1)):
image = image / 255.
row = ind // 5
col = ind % 5
axes[row][col].imshow(image, cmap='gray') # 灰度图
axes[row][col].axis('off')
axes[row][col].set_title('%d' % label)
plt.show()
由于电脑搭载cpu版本tensorflow,而且生成验证码的图片大小格式为(120,120,3),当batch_size取很大时,我的电脑跑不动,但是程序步骤没有问题
如果有很多人的电脑和我一样差的话,提供以下解决方法:
1、更加简化AlexNet网络模型,减去部分卷积层
2、batch_size取小一点
3、生成更小的验证码图片,具体的操作如下
将鼠标放置到ImageCaptcha()处,进去该对象的定义处
修改此处的width和height即可,但是需要width==height
希望大家一起进步!!!!!