之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介绍了使用Tensorflow Dataset进行数据导入的方法及其优势。最近在实际使用中越发感觉到这个方式非常好用,尤其是发现了.from_generator这个method。
关于Dataset.from_generator的简单介绍,请参见如下两个链接:
https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat
https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369
注意,Dataset.from_generator在旧版Tensorflow中没有,起码在1.3版本tf.contrib.data.Dataset中还没有,后来用的1.7版本就有了。
我们知道,tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。具体到data feeding上,也是如此。虽然设计了placeholder、train.slice_input_producer系列、Dataset等多种方式,但使用中仍有各种不便,尤其是在输入形式复杂、需要多重变换的时候更是如此。而Dataset.from_generator可以在一定程度上解决这个问题。
简单的说,Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:
# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834
"""
Expected outputs:
Batch No. 0:
[0 1 2 3]
Batch No. 1:
[4 0 1 2]
Batch No. 2:
[3 4 0 1]
Batch No. 3:
[2 3 4]
end!
"""
import numpy as np
import tensorflow as tf
def data_generator():
dataset = np.array(range(5))
for d in dataset:
yield d
dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
dataset = dataset.repeat(3)
dataset = dataset.batch(4)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
batch_num=0
while True:
one_batch = sess.run(one_element)
print('Batch No. %d:' % batch_num)
print(one_batch)
print('')
batch_num+=1
except tf.errors.OutOfRangeError:
print('end!')
很显然,这个的输出如下:
Batch No. 0:
[0 1 2 3]
Batch No. 1:
[4 0 1 2]
Batch No. 2:
[3 4 0 1]
Batch No. 3:
[2 3 4]
end!
下面给出一个复杂的问题。假设需要输入如下序列:
A B
A C B
C
…
其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。
针对这个问题,使用Dataset.from_generator的一个简化版示例如下:
# demo of Dataset.from_generator
# blog.csdn.net/foreseerwang
# QQ: 50834
"""
Expected outputs:
Batch No. 0:
[[ 1 2 3]
[ 2 3 -1]]
Batch No. 1:
[[ 3 -1 -1]
[ 4 5 -1]]
Batch No. 2:
[[ 6 7 8]
[ 9 -1 -1]]
Batch No. 3:
[[10 11 12]
[13 14 -1]]
Batch No. 4:
[[15 -1 -1]]
end!
"""
import io
import numpy as np
import tensorflow as tf
class DataFeeder:
def __init__(self, filenames):
self.filenames = filenames
def file_readline(self):
for filename in self.filenames:
fr = io.open(filename, 'r', encoding='utf-8')
while True:
file_line = fr.readline()
if not file_line:
break
datalist = file_line.split()
# if datalist is a list of filename, file contents can
# be read and appendded here.
yield np.asarray(datalist, dtype='int32')
fr.close()
def generate_batch(self, batch_size, num_epochs=None):
dataset = tf.data.Dataset.from_generator(self.file_readline,
tf.int32,
tf.TensorShape([None]))
dataset = dataset.repeat(num_epochs)
dataset = dataset.padded_batch(
batch_size,
padded_shapes=tf.TensorShape([3]),
padding_values=-1)
iterator = dataset.make_one_shot_iterator()
out_batch = iterator.get_next()
return out_batch
filenames = ['a.txt', 'b.txt', 'c.txt']
data_feeder = DataFeeder(filenames)
one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)
with tf.Session() as sess:
try:
batch_num = 0
while True:
data_batch = sess.run(one_batch)
print('Batch No. %d:' % batch_num)
print(data_batch)
print('')
batch_num+=1
except tf.errors.OutOfRangeError:
print('end!')
其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:
a.txt:
1 2 3
2 3
3
b.txt:
4 5
6 7 8
9
c.txt:
10 11 12
13 14
15
运行以上代码的输出为:
Batch No. 0:
[[ 1 2 3]
[ 2 3 -1]]
Batch No. 1:
[[ 3 -1 -1]
[ 4 5 -1]]
Batch No. 2:
[[ 6 7 8]
[ 9 -1 -1]]
Batch No. 3:
[[10 11 12]
[13 14 -1]]
Batch No. 4:
[[15 -1 -1]]
end!
目前的输出,每个batch是batch_size * 3的矩阵。实际上,1~15的数字可以是某个图片的文件名,在file_readline()函数中读出这些数字后,可以继续读出这些文件的内容,并形成更高维度的Dataset输出,譬如:batch_size * img_size * img_size * img_channel的Dataset。
最后,说几点注意事项(详见代码):
1. generator函数不能有输入参数,但如果是class内的一个函数,可以使用self参数,这也是传递参数的一个手段;
2. 上述class中,建议传递文件名,在generator中打开处理再关闭,而不应该在外面打开(fr=open(filename, ‘r’)),然后把fr传递给generator读取。实践表明:后面这种方法形成的dataset不能repeat;
3. 因为序列不等长,在形成dataset batch时需要使用Dataset.padded_batch方法。