1. The tf.assign() function:
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
    """Update 'ref' by assigning 'value' to it.

    This operation outputs a Tensor that holds the new value of 'ref' after
    the value has been assigned. This makes it easier to chain operations
    that need to use the reset value.

    Args:
      ref: A mutable `Tensor`.
        Should be from a `Variable` node. May be uninitialized.
      value: A `Tensor`. Must have the same type as `ref`.
        The value to be assigned to the variable.
      validate_shape: An optional `bool`. Defaults to `True`.
        If true, the operation will validate that the shape
        of 'value' matches the shape of the Tensor being assigned to. If false,
        'ref' will take on the shape of 'value'.
      use_locking: An optional `bool`. Defaults to `True`.
        If True, the assignment will be protected by a lock;
        otherwise the behavior is undefined, but may exhibit less contention.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
    """
    if ref.dtype._is_ref_dtype:
        return gen_state_ops.assign(
            ref, value, use_locking=use_locking, name=name,
            validate_shape=validate_shape)
    return ref.assign(value)
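Notes:
1. The tensor holds the new value only after the tf.assign() op has actually been run.
2. The validate_shape parameter defaults to True: the new value must have the same shape as the old value, otherwise an error is raised, as in the following example: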
import tensorflow as tf

a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30, 1])  # the ValueError is raised here, while the graph is being built
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("run a : ", sess.run(a))
    print("run b : ", sess.run(b))
    print("run a again : ", sess.run(a))
# validate_shape defaults to True, and [20, 30, 1] does not have the same shape as [10, 20].
# Error: ValueError: Dimension 0 in both shapes must be equal, but are 2 and 3. Shapes are [2] and [3]. for 'Assign' (op: 'Assign') with input shapes: [2], [3].
import tensorflow as tf

a = tf.Variable([10, 20])
# With validate_shape=False, a takes on the shape of the new value.
b = tf.assign(a, [20, 30, 1], validate_shape=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("run a : ", sess.run(a))
    print(a)
    print("run b : ", sess.run(b))
    print("run a again : ", sess.run(a))

Output:
run a : [10 20]
<tf.Variable 'Variable:0' shape=(2,) dtype=int32_ref>
run b : [20 30 1]
run a again : [20 30 1]
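
The two validate_shape behaviours above can be mimicked in plain Python without TensorFlow. This is only a rough sketch under stated assumptions: `ToyVariable` and its `assign` method are hypothetical names made up for illustration, not TensorFlow APIs, and the "shape" of a flat list is taken to be simply its length.

```python
class ToyVariable:
    """Hypothetical stand-in for a 1-D variable; not a TensorFlow class."""

    def __init__(self, value):
        self.value = list(value)

    def assign(self, new_value, validate_shape=True):
        new_value = list(new_value)
        # Mimic the shape check: for a flat list the "shape" is just its length.
        if validate_shape and len(new_value) != len(self.value):
            raise ValueError(
                "shapes must be equal, but are [%d] and [%d]"
                % (len(self.value), len(new_value)))
        # With the check disabled, the variable takes on the new shape.
        self.value = new_value
        return list(self.value)


a = ToyVariable([10, 20])

try:
    a.assign([20, 30, 1])  # lengths differ and the check is on
except ValueError as err:
    print("rejected:", err)

print(a.value)                                      # still [10, 20]
print(a.assign([20, 30, 1], validate_shape=False))  # [20, 30, 1]
```

One difference worth keeping in mind: the toy raises its error when `assign` is called, whereas the TensorFlow example above fails earlier, when the assign op is added to the graph.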
import tensorflow as tf

a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30])
c = b + [10, 20]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # => [10 20]
    print(sess.run(c))  # => [30 50]; c depends on b, so running c also runs b
    print(sess.run(a))  # => [20 30]
# Until the tf.assign() op is actually run, ref is not updated.
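
The deferred behaviour in the last example (defining the assign op changes nothing, running something that depends on it does) can be sketched with a toy graph in plain Python. Everything here is an assumption made for illustration: `Var`, `Assign` and `Add` are hypothetical classes, not TensorFlow's implementation, and each node simply runs its input when it is run itself.

```python
class Var:
    """Hypothetical variable node holding a mutable value."""

    def __init__(self, value):
        self.value = list(value)

    def run(self):
        return list(self.value)


class Assign:
    """Hypothetical assign node: updates `ref` only when it is run."""

    def __init__(self, ref, value):
        self.ref = ref
        self.new_value = list(value)

    def run(self):
        self.ref.value = list(self.new_value)
        return list(self.ref.value)


class Add:
    """Hypothetical element-wise add node; running it runs its input first."""

    def __init__(self, node, constant):
        self.node = node
        self.constant = list(constant)

    def run(self):
        return [x + y for x, y in zip(self.node.run(), self.constant)]


a = Var([10, 20])
b = Assign(a, [20, 30])  # building the node does not touch a
c = Add(b, [10, 20])     # c depends on b

print(a.run())  # [10, 20]  the assignment has not run yet
print(c.run())  # [30, 50]  running c runs b, which updates a
print(a.run())  # [20, 30]
```

The printed values mirror the three `sess.run` calls above: the variable keeps its old value until some node that depends on the assignment is evaluated.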