Padding code

import numpy as np

def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32):
    """Pad a ragged list of sequences into a (dim1_size, dim2_size) array.

    Sequences shorter than dim2_size are zero-padded; longer ones are truncated.
    """
    out_val = np.zeros((dim1_size, dim2_size), dtype=dtype)
    if dim1_size > len(in_vals): dim1_size = len(in_vals)  # don't read past the batch
    for i in range(dim1_size):  # range, not Python 2's xrange
        cur_in_vals = in_vals[i]
        cur_dim2_size = dim2_size
        if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals)  # truncate
        out_val[i, :cur_dim2_size] = cur_in_vals[:cur_dim2_size]
    return out_val

self.in_question_words = pad_2d_vals(self.in_question_words, self.batch_size, self.question_len, dtype=np.int32)

What the code above does is pad every question to a fixed length of question_len: sequences shorter than question_len are filled with zeros, and longer ones are truncated.
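
As a quick illustration, here is a toy batch run through pad_2d_vals (the id values are invented for the example):

word_id_seqs = [[4, 9], [7, 2, 5, 3], [8]]   # three questions of different lengths
padded = pad_2d_vals(word_id_seqs, dim1_size=3, dim2_size=3)
print(padded)
# [[4 9 0]
#  [7 2 5]
#  [8 0 0]]

The padded id matrix is then mapped to word vectors, and a mask built from the true sequence lengths zeroes out whatever the padded positions pick up: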

in_question_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_question_words) # [batch_size, question_len, word_dim]

question_mask = tf.sequence_mask(self.question_lengths, question_len, dtype=tf.float32) # [batch_size, question_len]

in_question_word_repres = tf.multiply(in_question_word_repres, tf.expand_dims(question_mask, axis=-1)) # zero out the vectors at padded positions
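
Why the mask is needed: tf.nn.embedding_lookup maps the padding id 0 to row 0 of the embedding table like any other id, so the padded positions carry non-zero vectors until the multiply wipes them out. Below is a minimal self-contained sketch of the three steps; it runs in eager mode, and vocab_size, word_dim, and the toy ids/lengths are assumptions made for the example:

import numpy as np
import tensorflow as tf

vocab_size, word_dim, question_len = 10, 4, 3              # toy sizes for the sketch
word_embedding = tf.random.normal([vocab_size, word_dim])  # stand-in embedding table

in_question_words = np.array([[4, 9, 0],
                              [8, 0, 0]], dtype=np.int32)  # padded word ids
question_lengths = np.array([2, 1], dtype=np.int32)        # true lengths before padding

repres = tf.nn.embedding_lookup(word_embedding, in_question_words)          # [2, 3, 4]
mask = tf.sequence_mask(question_lengths, question_len, dtype=tf.float32)   # [[1,1,0],[1,0,0]]
repres = tf.multiply(repres, tf.expand_dims(mask, axis=-1))                 # padded rows -> 0

After the multiply, repres[0, 2] and repres[1, 1:] are all-zero vectors, so downstream layers that sum or average over time are not polluted by the padding.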
