参考:tensorflow 训练格式TFRecord简单使用
1、tensorflow 训练格式TFRecord简单使用
保存:
import tensorflow as tf
# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
    """Wrap a text/bytes value in a tf.train.Feature holding a BytesList.

    Generalized to accept either `str` (encoded as UTF-8, as before) or
    pre-encoded `bytes` — the original `value.encode()` raised
    AttributeError when given bytes.
    """
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Wrap a single float in a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _int64list_feature(value_list):
    """Wrap a list of integers in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=value_list)
    return tf.train.Feature(int64_list=int64_list)
# Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
    """Assemble one sample into a tf.train.Example and serialize it to bytes.

    The feature dict below is the sample schema; any reader must parse with
    exactly these keys and matching types.
    """
    # NOTE(review): the key 'viewd_pois' looks like a typo of 'viewed_pois',
    # but it is kept verbatim — readers must match the written key exactly.
    features = tf.train.Features(feature={
        'user_id': _int64_feature(user_id),
        'city_id': _int64_feature(city_id),
        'app_type': _int64_feature(app_type),
        'viewd_pois': _int64list_feature(viewd_pois),
        'avg_paid': _float_feature(avg_paid),
        'comment': _bytes_feature(comment),
    })
    # Serialize the protobuf message into the byte string stored in TFRecord.
    return tf.train.Example(features=features).SerializeToString()
# 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
    """Write two hand-crafted demo samples into the TFRecord file `filepath`."""
    records = (
        serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."),
        serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."),
    )
    # TF1-style writer API; in TF2 the same class lives at tf.io.TFRecordWriter.
    with tf.python_io.TFRecordWriter(filepath) as writer:
        for record in records:
            writer.write(record)
    print ("write demo data done.")
# Demo driver: produce the sample TFRecord file in the working directory.
filepath = "testdata.tfrecord"
write_demo(filepath)
读取:
def read_demo(filepath):
    """Read a TFRecord file produced by `write_demo` and print every sample.

    Uses the TF1 dataset / one-shot-iterator / Session API.
    """
    # Parsing schema: keys and dtypes must match what serialize_example wrote.
    # BUG FIX: the writer stores the POI list under 'viewd_pois' (note the
    # typo); the original schema used 'viewed_pois', which silently parses
    # to an empty sparse tensor instead of the written values.
    schema = {
        'user_id': tf.FixedLenFeature([], tf.int64),
        'city_id': tf.FixedLenFeature([], tf.int64),
        'app_type': tf.FixedLenFeature([], tf.int64),
        'viewd_pois': tf.VarLenFeature(tf.int64),
        'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
        'comment': tf.FixedLenFeature([], tf.string, default_value=''),
    }

    def _parse_function(example_proto):
        # Deserialize one Example proto according to the schema above.
        return tf.parse_single_example(example_proto, schema)

    # Build the dataset from the TFRecord file and parse each record.
    dataset = tf.data.TFRecordDataset(filepath)
    parsed_dataset = dataset.map(_parse_function)
    # Renamed from `next` to avoid shadowing the builtin.
    next_element = parsed_dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        while True:
            try:
                print (sess.run(next_element))
            except tf.errors.OutOfRangeError:
                # End of dataset is the expected termination; the original
                # bare `except:` also hid genuine runtime errors.
                print ("out of data")
                break
# Demo driver: read back and print the records written above.
filepath = "testdata.tfrecord"
read_demo(filepath)
2、spark dataframe保存TFRecord
from pyspark.sql.types import *
def main(file_num=10,
         path="viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"):
    """Export a Hive table to TFRecord files on HDFS.

    Requires the spark-tensorflow-connector package to provide the
    "tfrecords" DataFrame output format.

    Args:
        file_num: number of output partitions, i.e. number of TFRecord files.
        path: destination directory (viewfs/HDFS URI).
    """
    # BUG FIX: `spark` and `file_num` were undefined free names in the
    # original (NameError outside a pyspark shell). `file_num` is now a
    # defaulted parameter and `spark` is obtained/created explicitly.
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    # Read the source rows from the Hive table.
    df = spark.sql("""
select * from experiment.table""")
    # Convert the Spark DataFrame to TFRecord (Example protos) on HDFS.
    df.repartition(file_num).write \
        .mode("overwrite") \
        .format("tfrecords") \
        .option("recordType", "Example") \
        .save(path)


if __name__ == "__main__":
    main()