TensorFlow 版本:1.10.0 > Guide > Introduction to Estimators
Estimator 概述
本篇将介绍 TensorFlow 中的 Estimators(可极大简化机器学习编程过程)。Estimators 中封装了以下几部分:
- 训练(training)
- 评估(evaluation)
- 预测(prediction)
- 输出模型(export for serving)
我们既可以使用内置 Estimator,也可以编写自定义 Estimator。
注意:TensorFlow 中的
tf.contrib.learn.Estimator
已经弃用了,请不要使用该 API。
文章目录
1. Estimator 的优势
Estimator 有以下优势:
- 对分布式的良好支持(不需要更改代码)。
- 有利于模型开发者之间的代码分享。
- 简化了模型的创建工作。
- Estimator 建立在 `tf.layers` 上,这简化了自定义 Estimator 的编写。
- Estimator 会为你创建 graph。
Estimator 提供了一个安全的分布式训练环境,其会帮我们控制这么、何时去:
- 建立 graph。
- 初始化 variables。
- 开始 queues。
- 处理 exceptions。
- 创建 checkpoint 文件,从失败中恢复训练。
- 保存 summaries for TensorBoard。
当用 Estimator 编写一个 application,你必须将 input pipeline 和 model 分开。这种分离简化了在不同数据集上的 experiments。
2. 内置的 Estimator
内置的 Estimator 使得你可以在更高层面思考问题。内置的 Estimator 会为你创建、管理 Graph 和 Session 对象。另外,内置的 Estimator 使你可以在最小的代码修改量的情况下实验不同的模型结构。
2.1 基于内置 Estimator 的程序的结构
使用内置 Estimator 的 TF 程序一般包含以下四步:
编写一个或多个数据集导入函数。 例如:创建一个函数来导入训练数据集,另一个函数来导入测试数据集。每一个数据集导入函数必须返回两个对象:
- 一个字典。字典的键名为特征的名字,键值为 表示特征数据的 Tensor 或 Sparse Tensor。
- 一个 Tensor。该 Tensor 包含一个或多个 label。
例如,下面的代码说明了 input function 的基本框架:
def input_fn(dataset): ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label
(更多细节详见 Importing Data。)
定义 feature columns。每一个
tf.feature_column
定义了一个 feature 的名字、类型、预处理。例如,下面的代码片段定义了三个 feature columns 来 Hold 整数、浮点 类型的数据。前两个 feature columns 只是简单的定义了特征的名字和类型。第三个 feature column 还使用了一个 lambda 函数来缩放原始数据:# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column('median_education', normalizer_fn=lambda x: x - global_education_mean)
实例化相应的内置 Estimator。 例如,实例化一个
LinearClassifier
:# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier( feature_columns=[population, crime_rate, median_education], )
调用一个 training、evaluation、inference 方法。例如,所有的 Estimator 提供了一个
train
方法来训练一个模型。# my_training_set is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)
2.2 内置 Estimator 的好处
内置 Estimator 的代码实现是最好的,另外其提供了以下好处:
- 对于分布式的支持是最好的。
- 保存 summaries 的保存策略是最好的。
如果你不使用内置 Estimator,你必须自己编写自定义函数来实现上面的特性。
3. 编写自定义 Estimator
Estimator 的核心是 model function,该函数负责 training、evaluation、prediction 计算图的建立。我们在另一个文档介绍了自定义 Estimator 的方法。
4. 推荐的工作流程
我们推荐的工作流程如下:
- 如果有合适的内置 Estimator,用它来作为你的第一个模型,将其结果作为一个基线。
- 基于内置 Estimator 建立、测试整个输入 pipeline(数据的完整性、可靠性)。
- 如果有更好的内置 Estimator 可用,可以根据实验的结果来确定哪个内置 Estimator 更合适。
- 可能的话,通过编写自定义 Estimator 来进一步提高模型性能。
5. 从 Keras 模型创建 Estimator
Keras 模型可以转换为 Estimator。这使得你的 Keras 模型能够利用 Estimator 的优势(比如:分布式训练)。转换通过 tf.keras.estimator.model_to_estimator
完成。
# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)
# Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names # print out: ['input_1'] >>>>>>> Very Important <<<<<<<
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"input_1": train_data},
y=train_labels,
num_epochs=1,
shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)
注意:由 keras 转换来的 estimator 的 feature columns 和 labels 的 name 要与 keras 模型输入输出的 name 保持一致。