(1)one shot iterator 单次迭代
仅支持对数据集进行一次迭代,不需要显式初始化。
import tensorflow as tf
# 通过对list切片的方式创建数据集
dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 2, 2]))
# 打乱顺序,设置batch,设置重复次数
dataset = dataset.shuffle(buffer_size=1000).batch(4).repeat(2)
# 创建迭代器
itr = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
v = sess.run([itr])
print(v)
运行结果:
[(array([1, 2, 5, 6]), array([0, 0, 2, 2]))]
[(array([4, 3]), array([1, 1]))]
[(array([5, 4, 1, 6]), array([2, 1, 0, 2]))]
[(array([2, 3]), array([0, 1]))]
Traceback (most recent call last):
File "D:\App\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
如果有101个样本,batch size是100,那么这个函数会返回两个batch,第一个是100,第二个是1。
如果想要强制每个batch 的数目是一样的那么就在batch 的时候设置drop_remainder=True
。这样101个样本,只会返回一个数目是100的batch,剩下的一个样本会被丢弃。
dataset.shuffle(buffer_size=1000).batch(128, drop_remainder=True).repeat(20000)
(2)initializable iterator 带参数的迭代器
这种迭代器允许使用placeholder,在运行时传送参数。
比如下面的例子,创建一个从0~p的数据集,p是参数,在运行时传入p。
import tensorflow as tf
# 对数据集传入参数 p
p = tf.placeholder(tf.int64, shape=[])
# 创建从0到p-1的数据集
dataset = tf.data.Dataset.range(p)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
# Initialize an iterator over a dataset with 3 elements.
sess.run(iterator.initializer, feed_dict={p: 3})
for i in range(3):
print(sess.run(next_element))
# Initialize the same iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={p: 10})
for i in range(10):
print(sess.run(next_element))
(3)可切换数据集的迭代器
这种情况下不同Dataset的尺寸要一致。
import tensorflow as tf
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(10).map(lambda x: x - 100).batch(5)
validation_dataset = tf.data.Dataset.range(10).batch(5)
# 根据training_dataset的形状创建迭代器
iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
with tf.Session() as sess:
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(5):
print("======")
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(2):
print(sess.run(next_element))
print("------")
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(1):
print(sess.run(next_element))
运行结果
======
[-100 -99 -98 -97 -96]
[-95 -94 -93 -92 -91]
------
[0 1 2 3 4]
======
[-100 -99 -98 -97 -96]
[-95 -94 -93 -92 -91]
------
[0 1 2 3 4]
====== (后面省略)
validation从未输出5~10,可以得出每一次执行初始化都会丢失之前的进度。
(4)可切换数据集并保留进度的迭代器
可feed迭代器可以与 tf.placeholder 一起使用,通过feed_dict机制选择每次调用 tf.Session.run 时所使用的 Iterator。它功能与(3)相同,但在迭代器之间切换时不需要从数据集的开头初始化迭代器。
import tensorflow as tf
p = tf.placeholder(tf.int64, shape=[])
# 定义训练数据集和验证数据集。两个数据集的结构相同。
training_dataset = tf.data.Dataset.range(10).map(lambda x: x - 100).batch(5).repeat(1000)
validation_dataset = tf.data.Dataset.range(p).batch(5).repeat(1000)
# 这种迭代器需要传入一个handle和数据集的结构
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# 创建两个不同的迭代器(训练迭代器和验证迭代器)
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
with tf.Session() as sess:
# 获得两个迭代器的句柄
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
sess.run(validation_iterator.initializer, feed_dict={p: 10})
while True:
print("======")
# 通过feed_dict=handle在两个迭代器之间切换。现在换成训练迭代器。
for _ in range(1):
print(sess.run(next_element, feed_dict={handle: training_handle}))
print("------")
# 现在换成迭代训练器。
for _ in range(1):
print(sess.run(next_element, feed_dict={handle: validation_handle}))
运行结果
======
[-100 -99 -98 -97 -96]
------
[0 1 2 3 4]
======
[-95 -94 -93 -92 -91]
------
[5 6 7 8 9]
======
[-100 -99 -98 -97 -96]
------
[0 1 2 3 4]
======
[-95 -94 -93 -92 -91]
------
[5 6 7 8 9] (后面的部分省略)
从运行结果可以看出,切换数据集时确实保留了上次的进度。