180512 TensorFlow 数据集 tf.data.Dataset 的基本操作

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

1. 从数组创建数据集 Dataset

# Step 1: instantiate the dataset from in-memory data,
# i.e. dataset = tf.data.Dataset.from_tensor_slices(your_real_data).
# Each element of `input_data` becomes one element of the dataset.
input_data = np.arange(9)
dataset = tf.data.Dataset.from_tensor_slices(input_data)
print(type(dataset))

# Step 2: a one-shot iterator walks the dataset once; each evaluation of
# its get_next() tensor produces one element.
iterator = dataset.make_one_shot_iterator()

# Step 3: build the computation on top of the get_next() tensor.
x = iterator.get_next()
y = tf.multiply(x, x)

with tf.Session() as sess:
    # One sess.run per element, so the iterator is consumed exactly once.
    for _ in input_data:
        print(sess.run(y))
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
0
1
4
9
16
25
36
49
64

2. 读取文本文件中的数据

# Step 1: create two small text files with two lines each.
sample_files = {
    './test_file_1.txt': [
        'test_file_1: This is the first line.\n',
        'test_file_1: This is the second line.\n',
    ],
    './test_file_2.txt': [
        'test_file_2: This is the third line.\n',
        'test_file_2: This is the fourth line.\n',
    ],
}
for file_path, file_lines in sample_files.items():
    with open(file_path, 'w') as file:
        file.writelines(file_lines)


# Step 2: a TextLineDataset yields one string element per line, reading the
# files in the order listed (the printed output shows no trailing newline).
files = ['test_file_1.txt','test_file_2.txt']
dataset = tf.data.TextLineDataset(files)

# Step 3: create the data iterator.
iterator = dataset.make_one_shot_iterator()

# Step 4: fetch and print every line; 4 is the number of lines written above.
x = iterator.get_next()
with tf.Session() as sess:
    for _ in range(4):
        print(sess.run(x))
b'test_file_1: This is the first line.'
b'test_file_1: This is the second line.'
b'test_file_2: This is the third line.'
b'test_file_2: This is the fourth line.'

3. 解析 TFRecord 文件中的数据。读取的文件 output.tfrecords 为本章第一节中创建的文件。

# Step 0: define the TFRecord parsing function.
def parser(record):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    The example must carry a scalar string feature ``'image_raw'`` and a
    scalar int64 feature ``'label'``.

    Returns:
        image: float32 tensor reshaped to ``[784]``; the raw bytes are
            interpreted as uint8 values and then cast to float32.
        label: the label cast to int32.
    """
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=feature_spec)

    # Unpack the byte string as uint8 values, promote to float32, and give
    # the result a fixed flat shape of 784 elements.
    pixels = tf.cast(tf.decode_raw(parsed['image_raw'], tf.uint8), tf.float32)
    image = tf.reshape(pixels, [784])

    label = tf.cast(parsed['label'], tf.int32)
    return image, label

# Step 1: create a dataset over the TFRecord file(s).
files = ['output.tfrecords']
dataset = tf.data.TFRecordDataset(files)

# Step 2: map() applies `parser` to every serialized record.
dataset = dataset.map(parser)

# Step 3: create the iterator.
iterator = dataset.make_one_shot_iterator()

# Step 4: get_next() returns the (image, label) pair produced by `parser`.
image, label = iterator.get_next()

# Step 5: fetch the first 10 samples and draw them on a 2x5 grid.
fig = plt.figure(figsize=(10,5))
with tf.Session() as sess:
    for i in range(10):
        # Fetch both tensors as ONE list in ONE sess.run call.
        # `sess.run(image, label)` is wrong: the second positional argument
        # of run() is feed_dict. Two separate run() calls would also advance
        # the iterator twice, pairing an image with the next sample's label.
        im, la = sess.run([image, label])
        ax = fig.add_subplot(2,5,i+1)
        ax.imshow(np.reshape(im,(28,28)))
        ax.set_axis_off()
        ax.set_title('Number %d'%la)

# In a non-interactive script the figure is built but never displayed
# unless show() is called.
plt.show()

【图：数据集（前 10 个样本及其标签，原图缺失）】

4. 使用 initializable_iterator 动态初始化数据集。与 one-shot 迭代器不同，它可以在运行时通过 placeholder 传入输入文件路径，但使用前必须先运行 iterator.initializer。

# Build the graph first; the input file list is a placeholder whose value is
# supplied later, when the iterator is initialized.
input_files = tf.placeholder(tf.string)

# Create the dataset and parse each record.
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)

# An initializable iterator is started explicitly via its initializer op,
# which is where the placeholder gets fed.
iterator = dataset.make_initializable_iterator()
image, label = iterator.get_next()

n = 0
with tf.Session() as sess:
    # Initialize the iterator and feed the file list (a list, hence the brackets).
    sess.run(iterator.initializer,
             feed_dict={input_files: ['output.tfrecords']})

    # Drain the dataset, counting one sample per successful fetch;
    # exhaustion is signalled by OutOfRangeError.
    while True:
        try:
            x, y = sess.run([image, label])
        except tf.errors.OutOfRangeError:
            break
        n += 1
    print('The total sample number is %d.'%n)
The total sample number is 55000.

猜你喜欢

转载自blog.csdn.net/qq_33039859/article/details/80289918