Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/qq_24548569/article/details/81747863
Reading data from CSV files
Reading CSV files in TensorFlow requires TextLineReader and decode_csv.
First, prepare two CSV files:
Contents of file1.csv:
,2,3,4,11
1,,3,4,12
1,2,,4,13
1,2,3,,14
Contents of file2.csv:
,2,3,4,21
1,,3,4,22
1,2,,4,23
1,2,3,,24
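The fifth (last) column is the sample's label: labels starting with 1 belong to samples from file1, and labels starting with 2 belong to samples from file2. Every row contains one empty field.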
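To reproduce the example, the two files can be generated with a small helper. This is a minimal sketch using only the standard library; the function name `write_sample_csvs` and its `directory` parameter are hypothetical and not part of the original code.

```python
from pathlib import Path

# The four feature rows shared by both files; "" marks the empty field.
ROWS = [
    ["", "2", "3", "4"],
    ["1", "", "3", "4"],
    ["1", "2", "", "4"],
    ["1", "2", "3", ""],
]


def write_sample_csvs(directory="."):
    """Write file1.csv and file2.csv; each label is <file number><row number>."""
    paths = []
    for file_no in (1, 2):
        lines = [
            ",".join(row + [f"{file_no}{row_no}"])  # append the label column
            for row_no, row in enumerate(ROWS, start=1)
        ]
        path = Path(directory) / f"file{file_no}.csv"
        path.write_text("\n".join(lines) + "\n")
        paths.append(path)
    return paths
```

Calling `write_sample_csvs()` with no argument writes both files into the current working directory, which is where the TensorFlow script looks for them.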
Code for reading the CSV files:
import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline
# Create a queue of filename strings; shuffle=True shuffles the file order
filename_queue = tf.train.string_input_producer(["file1.csv", "file2.csv"], shuffle=True)
# Use a TextLineReader to read the files
reader = tf.TextLineReader()
# Each read returns one line: key identifies the file and line ("filename:line"), value is the line's text
key, value = reader.read(filename_queue)
# Default values used when a field is empty
# The minus sign marks a default; the absolute value is the column number
record_defaults = [[-1], [-2], [-3], [-4], [0]]
# Parse the CSV line into five columns
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
# Stack the first four columns into a feature vector
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
    # Start the queue runners that fill the filename queue.
    # start_queue_runners must be called before run or eval executes read;
    # otherwise read blocks waiting for the filename queue to receive values.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Read 50 samples
    for i in range(50):
        key_run, value_run, example, label = sess.run([key, value, features, col5])
        print("step %02d file: %s raw_data: %s" % (i, key_run, value_run), end=" ")
        print("features: ", example, "label: ", label)
    coord.request_stop()
    coord.join(threads)
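The two CSV files contain only 8 samples in total, yet the loop reads 50, which shows that samples are read repeatedly: with the default num_epochs=None, string_input_producer cycles through the filename list indefinitely.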
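The repeated, file-by-file reading order can be modelled in plain Python. The sketch below is a simplified model, not TensorFlow's implementation: it assumes that with shuffle=True and num_epochs=None the filename list is reshuffled once per epoch, and that the reader then reads each file from its first line to its last before moving to the next file. The generator name `simulated_reader` is hypothetical.

```python
import itertools
import random


def simulated_reader(files, shuffle=True, seed=None):
    """Simplified model of string_input_producer + TextLineReader.

    `files` maps a filename to its list of lines. Each epoch the filename
    order is optionally reshuffled, then every file is read line by line.
    Like num_epochs=None, the generator never stops on its own.
    """
    rng = random.Random(seed)
    names = list(files)
    while True:  # one iteration per epoch
        if shuffle:
            rng.shuffle(names)
        for name in names:
            for line_no, line in enumerate(files[name], start=1):
                yield f"{name}:{line_no}", line


files = {
    "file1.csv": [",2,3,4,11", "1,,3,4,12", "1,2,,4,13", "1,2,3,,14"],
    "file2.csv": [",2,3,4,21", "1,,3,4,22", "1,2,,4,23", "1,2,3,,24"],
}
# Take 50 records, mimicking the 50 sess.run calls
records = list(itertools.islice(simulated_reader(files, seed=0), 50))
```

In this model the first 8 records cover every sample exactly once, later records repeat them, and each block of 4 consecutive records comes from a single file with line numbers 1 to 4.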
The output is as follows:
step 00 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 01 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 02 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 03 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 04 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 05 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 06 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 07 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 08 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 09 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 10 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 11 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 12 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 13 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 14 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 15 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 16 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 17 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 18 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 19 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 20 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 21 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 22 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 23 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 24 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 25 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 26 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 27 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 28 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 29 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 30 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 31 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 32 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 33 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 34 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 35 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 36 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 37 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 38 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 39 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 40 file: b'file2.csv:1' raw_data: b',2,3,4,21' features: [-1 2 3 4] label: 21
step 41 file: b'file2.csv:2' raw_data: b'1,,3,4,22' features: [ 1 -2 3 4] label: 22
step 42 file: b'file2.csv:3' raw_data: b'1,2,,4,23' features: [ 1 2 -3 4] label: 23
step 43 file: b'file2.csv:4' raw_data: b'1,2,3,,24' features: [ 1 2 3 -4] label: 24
step 44 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 45 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
step 46 file: b'file1.csv:3' raw_data: b'1,2,,4,13' features: [ 1 2 -3 4] label: 13
step 47 file: b'file1.csv:4' raw_data: b'1,2,3,,14' features: [ 1 2 3 -4] label: 14
step 48 file: b'file1.csv:1' raw_data: b',2,3,4,11' features: [-1 2 3 4] label: 11
step 49 file: b'file1.csv:2' raw_data: b'1,,3,4,12' features: [ 1 -2 3 4] label: 12
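In the output, every empty field appears as its negative default (-1 to -4), while non-empty fields keep their parsed values. The pure-Python function below illustrates that behaviour for this example. It is a hypothetical helper, not the real tf.decode_csv, and it assumes every column has a one-element default list whose type decides how the field is parsed (int here).

```python
def decode_csv_line(line, record_defaults, field_delim=","):
    """Illustrate how record_defaults fill empty CSV fields.

    Each default is a one-element list: an empty field takes the default
    value, and a non-empty field is converted to the default's type.
    """
    fields = line.split(field_delim)
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    columns = []
    for field, default in zip(fields, record_defaults):
        value = default[0] if field == "" else type(default[0])(field)
        columns.append(value)
    return columns


record_defaults = [[-1], [-2], [-3], [-4], [0]]
print(decode_csv_line("1,,3,4,12", record_defaults))
```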
Reading fixed-length records from binary files
Reading fixed-length records from a binary file requires FixedLengthRecordReader and decode_raw.
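Conceptually, FixedLengthRecordReader cuts a file into consecutive chunks of record_bytes bytes, and decode_raw with tf.uint8 turns each chunk into one integer per byte. The standard-library sketch below models only that; it ignores the reader's header_bytes and footer_bytes options, and the names `read_fixed_length_records` and `decode_raw_uint8` are hypothetical. Its keys imitate the "filename:index" form with a 0-based index that appears in the CIFAR-10 output at the end of this section.

```python
import io


def read_fixed_length_records(stream, record_bytes, name="data.bin"):
    """Yield (key, record) pairs of exactly record_bytes bytes each.

    This sketch simply stops at a trailing partial record.
    """
    index = 0
    while True:
        chunk = stream.read(record_bytes)
        if len(chunk) < record_bytes:
            break
        yield f"{name}:{index}", chunk
        index += 1


def decode_raw_uint8(record):
    """Turn each byte of the record into an integer in 0..255."""
    return list(record)


stream = io.BytesIO(bytes(range(10)))
for key, record in read_fixed_length_records(stream, 4, name="x.bin"):
    print(key, decode_raw_uint8(record))
```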
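The binary files used here are from the CIFAR-10 dataset (binary version). Every record has the same length: a 1-byte label followed by 3072 (32 × 32 × 3) bytes of image data. The image bytes are stored channel by channel: the first 1024 bytes are the red channel, the next 1024 the green channel, and the last 1024 the blue channel, each stored row by row. The dataset files can be downloaded from the official CIFAR-10 website (the author's upload on CSDN cannot be made a free download).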
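The label extraction and the reshape/transpose step in the TensorFlow code can be checked without TensorFlow. The sketch below decodes one 3073-byte record using plain Python indexing; `decode_cifar10_record` is a hypothetical helper that mirrors what strided_slice, reshape([depth, height, width]) and transpose([1, 2, 0]) do, returning image[row][col] as an (r, g, b) tuple.

```python
LABEL_BYTES = 1
HEIGHT, WIDTH, DEPTH = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH  # 3072
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073


def decode_cifar10_record(record):
    """Split one record into (label, image) with image[row][col] == (r, g, b).

    The image bytes are channel-major (depth x height x width); picking
    pixels[c * plane + row * WIDTH + col] for each channel c produces the
    height x width x depth layout, like transpose([1, 2, 0]).
    """
    if len(record) != RECORD_BYTES:
        raise ValueError(f"expected {RECORD_BYTES} bytes, got {len(record)}")
    label = record[0]
    pixels = record[LABEL_BYTES:]
    plane = HEIGHT * WIDTH  # bytes per color channel
    image = [
        [
            tuple(pixels[c * plane + row * WIDTH + col] for c in range(DEPTH))
            for col in range(WIDTH)
        ]
        for row in range(HEIGHT)
    ]
    return label, image
```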
The code is as follows:
import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline
from PIL import Image
# Build the list of filenames and the filename queue
data_dir = "cifar10_data/cifar-10-batches-bin"
filenames = [data_dir + ("/data_batch_%d.bin" % i) for i in range(1, 6)]
filename_queue = tf.train.string_input_producer(filenames)
label_bytes = 1
height = 32
width = 32
depth = 3
image_bytes = height * width * depth
# Length of each record
record_bytes = label_bytes + image_bytes
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
# Convert the raw byte string into a vector of uint8 values
record = tf.decode_raw(value, tf.uint8)
# Extract the label and cast it from uint8 to int32
label = tf.cast(tf.strided_slice(record, [0], [label_bytes]), tf.int32)
# Extract the image bytes and reshape them to depth × height × width
depth_major = tf.reshape(tf.strided_slice(record, [label_bytes], [record_bytes]), [depth, height, width])
# Transpose to height × width × depth
uint8image = tf.transpose(depth_major, [1, 2, 0])
with tf.Session() as sess:
    # Start the queue runners that fill the filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        key_run, label_run, image = sess.run([key, label, uint8image])
        # label has shape [1], so print its single element
        print("step %02d file: %s label: %d" % (i, key_run, label_run[0]))
        # Display the image
        image = Image.fromarray(image)
        image.show()
    coord.request_stop()
    coord.join(threads)
Output:
step 00 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:0' label: 8
step 01 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:1' label: 5
step 02 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:2' label: 0
step 03 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:3' label: 6
step 04 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:4' label: 9
step 05 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:5' label: 2
step 06 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:6' label: 8
step 07 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:7' label: 3
step 08 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:8' label: 6
step 09 file: b'cifar10_data/cifar-10-batches-bin/data_batch_3.bin:9' label: 2
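The printed labels are class indices from 0 to 9. The binary CIFAR-10 release lists the class names in batches.meta.txt; the list below reproduces that standard order, and `label_name` is a hypothetical helper for turning the output into readable names.

```python
# Class names in the order given by batches.meta.txt (index == label)
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def label_name(label):
    """Map a CIFAR-10 label index (0-9) to its class name."""
    return CIFAR10_CLASSES[label]


observed = [8, 5, 0, 6, 9, 2, 8, 3, 6, 2]  # labels printed in the output above
print([label_name(label) for label in observed])
```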