tf.function函数转换
1. 关于tf.function
tf.function的官方含义是“Creates a callable TensorFlow graph from a Python function.”也就是说,tf.function可以从 Python 函数创建可调用的 TensorFlow 图。这意味着什么呢?通俗点来说,我们可以自己编写python函数,然后利用tf.function进行函数转化,形成TensorFlow图结构,对函数进行加速。说白了,tf.function可以优化并加速我们自己编写的python函数。
2. tf.function 的实现
我们以一个激活函数“elu”为例,来看看经过tf.function的优化有什么变化:
# tf.function and auto-graph.
def scaled_elu(z, scale=1.0, alpha=1.0):
# z >= 0 ? scale * z : scale * alpha * tf.nn.elu(z)
is_positive = tf.greater_equal(z, 0.0)
return scale * tf.where(is_positive, z, alpha * tf.nn.elu(z))
print(scaled_elu(tf.constant(-3.)))
print(scaled_elu(tf.constant([-3., -2.5])))
"""
# 经过tf.function优化之后在输出结果
"""
scaled_elu_tf = tf.function(scaled_elu)
print(scaled_elu_tf(tf.constant(-3.)))
print(scaled_elu_tf(tf.constant([-3., -2.5])))
"""
# 将优化之后的scaled_elu_tf 再次转化为python下的函数,比较二者是否一致
"""
print(scaled_elu_tf.python_function is scaled_elu)
输出结果为:
这里我们发现,经过tf.function之后的函数功能上是没有任何变化的,然后我们对比一下二者的运行时间:
%timeit scaled_elu(tf.random.normal((1000, 1000)))
%timeit scaled_elu_tf(tf.random.normal((1000, 1000)))
输出结果为:
从这里可以看出,经过tf.funcion()优化之后的函数,具有更快的运行速度。
3. 关于@tf.function
tf.function除了第二节的的实现方式之外呢,我们还可以这么做:
@tf.function
def scaled_elu(z, scale=1.0, alpha=1.0):
# z >= 0 ? scale * z : scale * alpha * tf.nn.elu(z)
is_positive = tf.greater_equal(z, 0.0)
return scale * tf.where(is_positive, z, alpha * tf.nn.elu(z))
二者是等价的。