I. Reading .csv files with TensorFlow
The code throws an error
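In machine learning programming, data input has always been a problem for me. Early in my studies I kept using the MNIST handwritten digit project, which ships with its own data input code. As an aside, MNIST is a good way for beginners to get started with deep neural networks and convolutional neural networks: you don't have to think about how to feed in the data, and you only need to rewrite the core training code.
Sooner or later, though, every machine learning enthusiast has to deal with data input and preprocessing. If you take part in a machine learning competition, you will find that knowing how to read .csv files is essential.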
What is a .csv file?
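Here is the basic definition. A Comma-Separated Values (CSV) file stores tabular data (numbers and text) as plain text. CSV is sometimes also expanded as "character-separated values", because the separator does not have to be a comma.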
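To make the definition concrete, here is a small Python 3 sketch that uses the standard-library csv module rather than TensorFlow. It parses the contents of file3.csv as listed later in this post. Each line becomes one row, and the row is split on the delimiter, which defaults to a comma but can be changed. The helper name read_rows is purely illustrative.

```python
import csv
import io

# Contents of file3.csv from this post, as plain text
FILE3 = "1,2,3,4\n2,3,43,5\n4,5,56,6\n3,4,5,4\n"


def read_rows(text, delimiter=","):
    """Split delimited text into rows, converting every field to int (illustrative helper)."""
    reader = csv.reader(io.StringIO(text), delimiter=delimiter)
    return [[int(field) for field in row] for row in reader]


rows = read_rows(FILE3)
print(rows[0])  # [1, 2, 3, 4]

# The same table with spaces as the separator parses fine once the delimiter is stated
rows_space = read_rows(FILE3.replace(",", " "), delimiter=" ")
print(rows_space == rows)  # True
```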
I won't dwell on that. Here is the code from the official documentation, slightly adjusted for the .csv files I created:
```python
import tensorflow as tf

filename_queue = tf.train.string_input_producer(["file3.csv", "file4.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1]]
col1, col2, col3, col4 = tf.decode_csv(
    value, record_defaults=record_defaults)
# Old (TF 0.x) argument order: this is the line that fails below
features = tf.concat(0, [col1, col2, col3, col4])

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # for i in range(1200):
    #     Retrieve a single instance:
    #     example, label = sess.run([features])
    print(sess.run([features]))
    coord.request_stop()
    coord.join(threads)
```
My .csv files contain 4 columns of integers, so I adjusted the code above accordingly. Here are the contents of my .csv files:
file3.csv
```
1,2,3,4
2,3,43,5
4,5,56,6
3,4,5,4
```
file4.csv
```
1,2,3,4
2,3,4,45
2,3,3,4
3,45,5,3
```
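The record_defaults argument does two jobs: it fixes how many columns each line must have, and it sets the type and fallback value of each column. The sketch below is a simplified, hypothetical Python model of what tf.decode_csv does with one line. It is not TensorFlow's implementation. As far as I know, the real op fails at run time with an InvalidArgumentError when a line's field count does not match len(record_defaults). The real op also takes a field_delim parameter (default ','), and the model mirrors it. The function name decode_line is made up for illustration.

```python
# Hypothetical, simplified model of tf.decode_csv applied to a single line.
# NOT TensorFlow's implementation; it only mirrors the record_defaults idea.

def decode_line(line, record_defaults, field_delim=","):
    """Decode one text line into a list of typed values (illustrative helper)."""
    fields = line.split(field_delim)
    if len(fields) != len(record_defaults):
        # The column count is fixed by record_defaults
        raise ValueError("Expected %d fields but got %d in line %r"
                         % (len(record_defaults), len(fields), line))
    values = []
    for field, default in zip(fields, record_defaults):
        default_value = default[0]  # each default is a one-element list, e.g. [1]
        if field == "":
            values.append(default_value)  # empty column: fall back to the default
        else:
            values.append(type(default_value)(field))  # convert to the default's type
    return values


record_defaults = [[1], [1], [1], [1]]
print(decode_line("2,3,43,5", record_defaults))  # [2, 3, 43, 5]
print(decode_line("2,,43,5", record_defaults))   # [2, 1, 43, 5]

# Space-separated input with the default comma delimiter yields only 1 field
try:
    decode_line("2 3 43 5", record_defaults)
except ValueError as err:
    print(err)

# Stating the delimiter makes the space-separated line decode correctly
print(decode_line("2 3 43 5", record_defaults, field_delim=" "))  # [2, 3, 43, 5]
```

Under this model, the space-separated attempt mentioned next would fail for the same reason. Passing field_delim=' ' to tf.decode_csv is presumably the TensorFlow-side fix, but I have not tested it here.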
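Both files have 4 rows of 4 values each. Since the definition of .csv mentions comma-separated values, I separated the values on each row with commas. I did try separating them with spaces instead, but the code then failed with an error.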
The output is as follows:
```
/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv1.py
Traceback (most recent call last):
  File "/home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv1.py", line 13, in <module>
    features = tf.concat(0, [col1, col2, col3, col4])
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1187, in concat
    tensor_shape.scalar())
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/framework/tensor_shape.py", line 844, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (4,) and () are incompatible

Process finished with exit code 1
```
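I glanced at the clock: the supermarket would close in half an hour. I never give up until a problem is solved, so if I couldn't fix this in time, I would have to buy my water from the vending machine in the dorm compound instead.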
II. The corrected code
Happily, I did make it to the supermarket for my water.
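tf.concat
I found two versions of this function's signature online:
The wrong version (the pre-1.0 signature from TensorFlow 0.x, which is now outdated):
tf.concat(concat_dim, values, name='concat')
The correct version (TensorFlow 1.0 and later, where the second parameter is named axis):
tf.concat(values, axis, name='concat')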
concat_dim (axis in 1.x) is the dimension along which the tensors are joined, counted from 0: 0 is the first dimension, 1 is the second, and so on.
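The following sketch uses NumPy rather than TensorFlow and assumes NumPy is installed. It shows what the dimension argument does. np.concatenate takes the list of arrays first and the axis second, which is the same order as tf.concat(values, axis) in TensorFlow 1.x. The shapes show which dimension grows.

```python
import numpy as np

# Two 2x2 blocks; the axis argument picks which dimension grows
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])

rows_joined = np.concatenate([a, b], axis=0)  # append rows: shape (4, 2)
cols_joined = np.concatenate([a, b], axis=1)  # extend each row: shape (2, 4)

print(rows_joined.shape)  # (4, 2)
print(cols_joined.shape)  # (2, 4)
print(cols_joined[0])     # [1 2 5 6]
```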
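The official documentation code uses the former signature and passes concat_dim = 0 as the first argument to concat, which is wrong in TensorFlow 1.x! With the 1.x signature, 0 is bound to values and the list of four columns is bound to axis. Because axis must be a scalar, this produces the error "Shapes (4,) and () are incompatible" seen above.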
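After fixing the argument order, the code becomes the version below. Each column is also wrapped in brackets ([col1] and so on). tf.decode_csv returns every column as a scalar tensor, and tf.concat cannot join scalars, so each column first has to become a length-1 vector. tf.stack adds a new dimension and does the same job directly.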
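```python
import tensorflow as tf

# Queue of file names; TextLineReader reads the files one line at a time
filename_queue = tf.train.string_input_producer(["file3.csv", "file4.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Four integer columns, with 1 as the default for empty fields
record_defaults = [[1], [1], [1], [1]]
col1, col2, col3, col4 = tf.decode_csv(
    value, record_defaults=record_defaults)

# TF >= 1.0 order: tf.concat(values, axis). Each scalar column is wrapped
# in a list so that it has shape [1] before concatenation.
features = tf.concat([[col1], [col2], [col3], [col4]], 0)
# features = tf.stack([col1, col2, col3, col4])  # alternative: comment out the line above and use this one

init_op = tf.global_variables_initializer()
# Local variables must be initialized if num_epochs is passed to string_input_producer
local_init_op = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init_op, local_init_op])
    # Start the threads that fill the filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    print(sess.run([features]))
    # for i in range(9):
    #     print(sess.run([features]))
    coord.request_stop()
    coord.join(threads)
```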
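The sketch below reproduces, with NumPy as a stand-in for TensorFlow, why the brackets around each column are needed. It assumes NumPy is installed and is only an analogy, since the error messages differ. The 0-d arrays play the role of the scalar columns returned by tf.decode_csv. Concatenation needs an existing axis, so scalars must first become length-1 vectors. Stacking creates the new axis itself. In TensorFlow 1.x, tf.concat on scalars should likewise fail with an error that points to tf.stack.

```python
import numpy as np

# Model the four scalar columns from decode_csv as 0-d arrays
col1, col2, col3, col4 = (np.array(v) for v in (1, 2, 3, 4))

# Concatenating 0-d arrays fails: there is no axis to join along
try:
    np.concatenate([col1, col2, col3, col4], axis=0)
    concat_ok = True
except ValueError as err:
    concat_ok = False
    print("concatenate failed:", err)

# Wrapping each scalar in a list gives it shape (1,), as in [[col1], [col2], ...]
wrapped = np.concatenate([[col1], [col2], [col3], [col4]], axis=0)

# Stacking adds a new leading axis, turning four scalars into a length-4 vector
stacked = np.stack([col1, col2, col3, col4])

print(wrapped)  # [1 2 3 4]
print(stacked)  # [1 2 3 4]
```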
The output is as follows:
```
/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py
2018-07-18 00:29:40.698139: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
[array([1, 2, 3, 4], dtype=int32)]

Process finished with exit code 0
```
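The first line of file3.csv has been printed. To print all of the data, replace the print(...) line with the two commented-out lines near the end of the script (the for loop). The result is: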
```
/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py
2018-07-18 00:33:10.401338: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
[array([1, 2, 3, 4], dtype=int32)]
[array([ 2, 3, 43, 5], dtype=int32)]
[array([ 4, 5, 56, 6], dtype=int32)]
[array([3, 4, 5, 4], dtype=int32)]
[array([1, 2, 3, 4], dtype=int32)]
[array([ 2, 3, 4, 45], dtype=int32)]
[array([2, 3, 3, 4], dtype=int32)]
[array([ 3, 45, 5, 3], dtype=int32)]
2018-07-18 00:33:10.419660: W tensorflow/core/kernels/queue_base.cc:277] _0_input_producer: Skipping cancelled enqueue attempt with queue not closed
Traceback (most recent call last):
  File "/home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py", line 26, in <module>
    print(sess.run([features]))
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: <exception str() failed>

Process finished with exit code 1
```
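All eight lines from the two files are printed, but the ninth sess.run call raises an InvalidArgumentError. Sometimes a run prints only the contents of file3.csv. I have not solved this yet (see TensorFlow中的队列 (Queues in TensorFlow): https://blog.csdn.net/huachao1001/article/details/78083125 ).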