import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# 1. 从数组创建数据集 (Dataset)
# 1. Create a Dataset from an in-memory array.
# Build a Dataset whose elements are the individual entries of a NumPy array,
# then pull them one at a time through a one-shot iterator (TF 1.x graph API).
input_data = np.arange(9)
dataset = tf.data.Dataset.from_tensor_slices(input_data)
print(type(dataset))

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()  # symbolic "next element" tensor
y = x * x                # recomputed for the next element on every sess.run
with tf.Session() as sess:
    # One run per array entry; each sess.run(y) advances the iterator once.
    for _ in range(input_data.size):
        print(sess.run(y))
# Output:
# <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
# 0
# 1
# 4
# 9
# 16
# 25
# 36
# 49
# 64
# 2. 读取文本文件里面的数据
# 2. Read data from text files.
# Write two small text files, then stream their lines through a
# TextLineDataset. The files are read in list order and each element is one
# line as a scalar string tensor, without its trailing newline.
file_contents = {
    'test_file_1.txt': ['test_file_1: This is the first line.',
                        'test_file_1: This is the second line.'],
    'test_file_2.txt': ['test_file_2: This is the third line.',
                        'test_file_2: This is the fourth line.'],
}
for path, lines in file_contents.items():
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)

# Single source of truth: the files read are exactly the files written above.
files = list(file_contents)
dataset = tf.data.TextLineDataset(files)
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()
with tf.Session() as sess:
    # Drain the iterator instead of hard-coding the line count, so editing the
    # file contents above can neither crash the loop nor silently skip lines.
    while True:
        try:
            print(sess.run(x))
        except tf.errors.OutOfRangeError:
            break
# Output:
# b'test_file_1: This is the first line.'
# b'test_file_1: This is the second line.'
# b'test_file_2: This is the third line.'
# b'test_file_2: This is the fourth line.'
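# 3. 解析 TFRecord 文件里的数据。读取的文件是本章第一节创建的文件。
# 3. Parse the records of a TFRecord file. The file read here was created in
#    the first section of this chapter.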
def parser(record):
    """Decode one serialized tf.train.Example into an (image, label) pair.

    The Example must carry two features: ``image_raw``, a byte string of raw
    uint8 pixel values, and ``label``, an int64 scalar.

    Args:
        record: scalar string tensor holding one serialized Example, e.g. one
            element of a ``tf.data.TFRecordDataset``.

    Returns:
        image: float32 tensor of shape [784] holding the pixel values cast
            from uint8 (no rescaling is applied).
        label: int32 scalar tensor.
    """
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=feature_spec)

    # Raw bytes -> uint8 vector -> float32, then pin the static shape; the
    # reshape errors at run time if a record does not hold exactly 784 bytes.
    pixels = tf.cast(tf.decode_raw(parsed['image_raw'], tf.uint8), tf.float32)
    image = tf.reshape(pixels, [784])

    return image, tf.cast(parsed['label'], tf.int32)
# Read the first records of the TFRecord file, decode them with `parser`, and
# show the images in a grid titled with their labels.
files = ['output.tfrecords']
dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(parser)
iterator = dataset.make_one_shot_iterator()
image, label = iterator.get_next()

n_rows, n_cols = 2, 5  # grid layout; one subplot per displayed sample
fig = plt.figure(figsize=(10, 5))
with tf.Session() as sess:
    for i in range(n_rows * n_cols):
        im, la = sess.run([image, label])
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        # `parser` yields a flat 784-vector; restore the 28 x 28 layout. The
        # pixels are single-channel intensities, so draw them in grayscale
        # rather than with the default false-colour colormap.
        ax.imshow(np.reshape(im, (28, 28)), cmap='gray')
        ax.set_axis_off()
        ax.set_title('Number %d' % la)
# Without this call the figure is never rendered when run as a plain script
# (notebooks with an inline backend display it automatically).
plt.show()
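# 4. 使用 initializable_iterator 来动态初始化数据集。
# 4. Use an initializable iterator to initialize the dataset dynamically
#    (the input file list is fed in when the iterator is initialized).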
# Same parsing pipeline, but the file list is a placeholder that is fed when
# the iterator is initialized, so one graph can be pointed at different files.
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files).map(parser)
iterator = dataset.make_initializable_iterator()
image, label = iterator.get_next()

n = 0
with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={input_files: ['output.tfrecords']})
    # Count records by pulling elements until the dataset is exhausted;
    # OutOfRangeError is the normal end-of-data signal, not a failure.
    try:
        while True:
            x, y = sess.run([image, label])
            n += 1
    except tf.errors.OutOfRangeError:
        pass
print('The total sample number is %d.' % n)
# Output:
# The total sample number is 55000.