一、Constants
最简单的方法。
import tensorflow as tf
import numpy as np
actual_data = np.random.normal(size=[100])
data = tf.constant(actual_data)
这种方式效率很高,但是当需要应用于其他数据时,必须重写。并且这种方式一次性将所有数据加载到内存,仅适用小数据集。
二、Placeholders
import tensorflow as tf
import numpy as np
data = tf.placeholder(tf.float32)
prediction = tf.square(data) + 1
actual_data = np.random.normal(size=[100])
tf.Session().run(prediction, feed_dict={data: actual_data})
placeholders 是在session run中通过feed_dict来feed数据。
三、Python ops
def py_input_fn():
actual_data = np.random.normal(size=[100])
return actual_data
data = tf.py_func(py_input_fn, [], (tf.float32))
四、Dataset API
tensorflow中推荐使用dataset API来feed 数据。
actual_data = np.random.normal(size=[100])
dataset = tf.contrib.data.Dataset.from_tensor_slices(actual_data)
data = dataset.make_one_shot_iterator().get_next()
dataset = dataset.cache()
if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.repeat()
dataset = dataset.shuffle(batch_size * 5)
dataset = dataset.map(parse, num_threads=8)
dataset = dataset.batch(batch_size)
如果是从文件中获取数据,将文件数据转化成TFrecord格式再用TFRecordDataset读取效率更高。
dataset = tf.contrib.data.TFRecordDataset(path_to_data)