版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/pku_langzi/article/details/81239291
def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32):
    """Pad a ragged list of sequences into a dense 2-D array.

    Args:
        in_vals: list of row sequences (e.g. lists of word ids); rows may
            have different lengths.
        dim1_size: number of rows in the output (extra input rows are dropped).
        dim2_size: number of columns in the output (longer rows are truncated,
            shorter rows are right-padded with 0).
        dtype: numpy dtype of the result (default np.int32).

    Returns:
        np.ndarray of shape (dim1_size, dim2_size).
    """
    out_val = np.zeros((dim1_size, dim2_size), dtype=dtype)
    # Clamp to the rows actually available; remaining rows stay all-zero.
    # (Original used Python-2-only `xrange`, a NameError on Python 3.)
    n_rows = min(dim1_size, len(in_vals))
    for i in range(n_rows):
        cur_in_vals = in_vals[i]
        # Truncate over-long rows; shorter rows keep their zero padding.
        cur_dim2_size = min(dim2_size, len(cur_in_vals))
        out_val[i, :cur_dim2_size] = cur_in_vals[:cur_dim2_size]
    return out_val
# Replace the ragged per-question word-id lists with a dense
# (batch_size, question_len) int32 array: rows longer than question_len are
# truncated, shorter rows are right-padded with zeros (see pad_2d_vals above).
self.in_question_words = pad_2d_vals(self.in_question_words, self.batch_size, self.question_len, dtype=np.int32)
上面代码的意思其实就是按照序列最长为 question_len 的标准对每个序列进行填充:超过的部分截断,不够的部分补 0。
# Look up embeddings for the padded question word ids.
in_question_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_question_words) # [batch_size, question_len, word_dim]
# 1.0 for real tokens, 0.0 for padding positions beyond each question's length.
# NOTE(review): bare `question_len` here vs. `self.question_len` in the padding
# call above — confirm which name is in scope at this point.
question_mask = tf.sequence_mask(self.question_lengths, question_len, dtype=tf.float32) # [batch_size, question_len]
# Zero out embeddings at padded positions. BUG FIX: the original multiplied
# `in_question_repres` before it was ever assigned (NameError); the mask must
# be applied to the embedding-lookup output computed above.
in_question_repres = tf.multiply(in_question_word_repres, tf.expand_dims(question_mask, axis=-1))