如果 enqueue_many=False
,则 tensor
表示单个样本.
对于 shape 为 [x, y, z]
的输入 tensor,该函数输出为,shape 为 [batch_size, x, y, z]
的 tensor.
如果 enqueue_many=True
,则 tensors
表示 batch 个样本,其中,第一维表示样本的索引,所有的 tensors
都在第一维具有相同的尺寸.
对于 shape 为 [*, x, y, z]
的输入 tensor,该函数输出为,shape 为 [batch_size, x, y, z]
的 tensor.
一、上面的解释看了可能不会一下子懂。举个实际的例子,假设有a.csv文件有101个样本,shape为[101,95],第一列是label列,取值依次为1 2 3 4 ... 101
如果 SHAPE=10
key, value = reader.read_up_to(file_queue, num_records=SHAPE)
batch_values = tf.train.batch([value], batch_size=SHAPE, capacity=64000, enqueue_many=True)
那么每次sess.run()的时候,处理到的样本数都是10,并且是按照顺序来处理1.csv文件,处理10次后,第11次处理的样本是a.csv的最后一条样本+a.csv的最前面9条样本。所以会一直处理下去,不会停止
二、
batch_values = tf.train.batch([value], batch_size=SHAPE, capacity=64000, enqueue_many=False)
则tensor
表示单个样本,也就是read_up_to每次处理的是10个样本,系统将这10条样本当成单个样本,又因为batch_size=10,所以每次处理的是10*10个样本
即第一次就会处理前100条样本
# [[ 1 2 3 4 5 6 7 8 9 10]
# [ 11 12 13 14 15 16 17 18 19 20]
# [ 21 22 23 24 25 26 27 28 29 30]
# [ 31 32 33 34 35 36 37 38 39 40]
# [ 41 42 43 44 45 46 47 48 49 50]
# [ 51 52 53 54 55 56 57 58 59 60]
# [ 61 62 63 64 65 66 67 68 69 70]
# [ 71 72 73 74 75 76 77 78 79 80]
# [ 81 82 83 84 85 86 87 88 89 90]
# [ 91 92 93 94 95 96 97 98 99 100]]
而又因为,前面肯定有设定features.set_shape([10,94]),所以第二次处理的时候,肯定会报错
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, Shape mismatch in tuple component 0. Expected [10,94], got [1,94]
然后程序停止。
补充:为什么前面必须设定features.set_shape([10,94])
因为在enqueue_many=False的时候,流处理数据的时候,如果是像下面这样,先把features labels给解析出来了,然后再输入到tf.train.batch函数里,就必需在两段代码中加上 features.set_shape([10,94]) labels.set_shape([10]) 指定两个变量的shape,否则直接报错说 连一次都不运行
ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(94)]), TensorShape([Dimension(10)])]
key, value = reader.read_up_to(file_queue, num_records=SHAPE)
record_defaults = [[1.0]]*95
vall = tf.decode_csv(value, record_defaults=record_defaults)
labels = tf.reshape(tf.cast(vall[0], tf.int32), [SHAPE])
features = tf.stack(vall[1:], axis=1)
# features.set_shape([10,94])
# labels.set_shape([10])
features, labels = tf.train.batch([features, labels], batch_size=SHAPE, capacity=64000, enqueue_many=False)
三、但是当enqueue_many=True的时候,指不指定
# features.set_shape([10,94])
# labels.set_shape([10])
都可以,不会影响到啥
四、
batch_values = tf.train.shuffle_batch([value], batch_size=SHAPE, capacity=64000, enqueue_many=True, min_after_dequeue=int(SHAPE / 2))
当训练集中是101个样本,SHAPE大小为10。可以保证在101个样本都被用过一遍后,才会出现一个样本被再次用到,而且会持续不断的迭代出数据下去!