tensorflow 利用placeholder选择每个batch里的sub-tensor 实例

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/guotong1988/article/details/84335912

一种做法:先把张量的前两维(batch 和 seq)合并展平(reshape 成 -1),再给第 i 个样本的位置下标加上偏移量 i * seq_length,最后用 tf.gather 按展平后的下标取值。

import tensorflow as tf

def gather_indexes_2d(sequence_tensor, positions):
  """Gather one element per row of a rank-2 tensor.

  Args:
    sequence_tensor: tensor of shape [batch_size, seq_length].
    positions: int32 tensor, expected shape [batch_size]; positions[i] is the
      column to read from row i.

  Returns:
    A tensor with output[i] = sequence_tensor[i, positions[i]]; its shape is
    [batch_size] when positions has shape [batch_size].

  Dimensions that are unknown at graph-construction time (e.g. a None batch
  size) are read from tf.shape at run time. Positions are not range-checked:
  a value >= seq_length reads from a later row rather than being rejected.
  """
  static_shape = sequence_tensor.shape.as_list()

  def dim(index):
    # Static Python int when known; otherwise the runtime size as a tensor.
    size = static_shape[index]
    return size if size is not None else tf.shape(sequence_tensor)[index]

  batch_size = dim(0)
  seq_length = dim(1)

  # After flattening, row i starts at flat index i * seq_length.
  flat_offsets = tf.range(0, batch_size, dtype=tf.int32) * seq_length
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  # Reshape to exactly batch_size * seq_length (not [-1]) so inputs with
  # extra trailing dimensions fail instead of being silently flattened.
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length])
  return tf.gather(flat_sequence_tensor, flat_positions)

value = [[0, 1], [2, 3], [4, 5]]
value_init = tf.constant_initializer(value)
v = tf.get_variable('value', shape=[3, 2], initializer=value_init,
                    dtype=tf.int32)

# One column index per row of `v`.
p = tf.placeholder(shape=[3], dtype=tf.int32)

v_ = gather_indexes_2d(v, p)

# tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer(); the `with` block closes the session.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(v_, feed_dict={p: [1, 1, 0]}))  # -> [1 3 4]

打印结果:[1 3 4](第 0 行取下标 1,第 1 行取下标 1,第 2 行取下标 0)

rank3的情况:

import tensorflow as tf

def gather_indexes_3d(sequence_tensor, positions):
  """Gather one [width] vector per batch entry of a rank-3 tensor.

  Args:
    sequence_tensor: tensor of shape [batch_size, seq_length, width].
    positions: int32 tensor, expected shape [batch_size]; positions[i]
      selects which of the seq_length vectors of entry i to return.

  Returns:
    A tensor with output[i] = sequence_tensor[i, positions[i]]; its shape is
    [batch_size, width] when positions has shape [batch_size].

  Dimensions that are unknown at graph-construction time (e.g. a None batch
  size) are read from tf.shape at run time. Positions are not range-checked:
  a value >= seq_length selects a vector belonging to a later batch entry
  rather than being rejected.
  """
  static_shape = sequence_tensor.shape.as_list()

  def dim(index):
    # Static Python int when known; otherwise the runtime size as a tensor.
    size = static_shape[index]
    return size if size is not None else tf.shape(sequence_tensor)[index]

  batch_size = dim(0)
  seq_length = dim(1)
  width = dim(2)

  # After merging the first two axes, entry i starts at row i * seq_length.
  flat_offsets = tf.range(0, batch_size, dtype=tf.int32) * seq_length
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  return tf.gather(flat_sequence_tensor, flat_positions)


v = tf.constant([[[1, 1], [2, 2], [3, 3]],
                 [[4, 4], [5, 5], [6, 6]]])  # shape [2, 3, 2]
# One position per batch entry of `v`.
p = tf.placeholder(shape=[2], dtype=tf.int32)

v_ = gather_indexes_3d(v, p)

# This graph only uses a constant and a placeholder, so no variable
# initializer is needed; the `with` block closes the session.
with tf.Session() as sess:
  print(sess.run(v_, feed_dict={p: [1, 0]}))  # -> [[2 2] [4 4]]

打印结果(第 0 个样本取下标 1,第 1 个样本取下标 0):
[[2 2]
 [4 4]]

猜你喜欢

转载自blog.csdn.net/guotong1988/article/details/84335912