在TensorFlow框架中,提供了函数tf.py_func实现自动将Tensor对象转为Python对象,并作为Python函数的形参传入,同时Python函数返回的结果又自动转为Tensor对象返回。也就是说,tf.py_func函数能执行指定的Python函数,并自动将Tensor对象转为Python对象,再将返回的Python对象转为Tensor对象, tf.py_func函数原型如下:
tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
各个参数的意义:
1. func:Python函数类型,指定要执行的函数。
2. lnp:list类型,list里面存放的是Tensor对象,用于传入func函数作为形参。
3. Tout:list类型或是单个对象,存放的是TensorFlow的数据类型,用于描述func函数返回数据转为Tensor对象后的数据类型。
4. stateful:bool类型,默认为True。如果设置为True, 则该函数被认为是与状态有关的。如果函数与状态无关,则相同的输入会产生相同的输出。
5. name:当前Operation的名称。
简单示例:
import tensorflow as tf
def my_add_func(A, B):
# 查看传入的参数的数据类型
print('type(A) =', type(A))
print('type(B) =', type(B))
C = A + B
return C
# 定义Tensor对象
A_tf = tf.constant([[1, 1], [1, 1]], dtype=tf.int64)
B_tf = tf.constant([[2, 2], [2, 2]], dtype=tf.int64)
C_tf = tf.py_func(my_add_func, [A_tf, B_tf], tf.int64)
with tf.Session() as sess:
C = sess.run(C_tf)
print("C =\n", C)
print("type(C_tf) =", type(C_tf))
print("type(C) =", type(C))
输出:
type(A) = <class 'numpy.ndarray'>
type(B) = <class 'numpy.ndarray'>
C =
[[3 3]
[3 3]]
type(C_tf) = <class 'tensorflow.python.framework.ops.Tensor'>
type(C) = <class 'numpy.ndarray'>