pytorch处理padding变长后的RNN输入

为什么要处理变长输入?

一般的,在通过embedding层转换为词向量之前,我们的输入形式如下:batch_size * max_len,每一个句子都是一个列表,其中的元素是单词对应的下标。
如果一个句子原来的长度<max_len,我们需要进行padding操作,即在这个列表里填0,0下标表示< pad >这一个单词,它的词向量表示一般是随机初始化的。
在这里插入图片描述在这里,max_len就是6,那么对于第二个句子而言,它需要padding一次,…第五个句子需要padding五次。

以第五个句子"Yes"为例,我们把它放入一个序列模型(RNN\LSTM),我们只需要第一个时间步,也就是处理Yes这个单词时,得到的隐藏层输出:
在这里插入图片描述但是如果我们不做变长的操作:我们只能得到处理最后一个pad时的输出,这样会影响性能。
在这里插入图片描述

pytorch如何处理变长?

主要涉及两个函数:

  1. torch.nn.utils.rnn.pack_padded_sequence()
  2. torch.nn.utils.rnn.pad_packed_sequence()

第一个函数用于输入模型前的压缩
第二个函数用于模型已经输出后,对之前的压缩进行恢复(把删掉的padding过的0给它补回来)

具体如何使用?

先讨论第一个函数,它比较难用。使用的例子如下:

packed_batch= torch.nn.utils.rnn.pack_padded_sequence(sorted_batch, sorted_seq_lens,batch_first=True)

直接讲输入和输出吧
输入如下:

  • sorted_batch(batch_size x max_len x word_dim):
    它就是经过embedding之后的词向量矩阵,只不过它已经按照,当前batch下所有句子的长度,从长到短排过序了,也就是刚刚上面这张图(巧了)
    在这里插入图片描述

  • sorted_seq_lens(batch_size * 1):
    它记录了排过序的所有句子的长度(按上图来看就是:6,5,4,3,2,1)

  • batch_first:
    等于True了以后,就会按照batch_size x max_len x word_dim的规格处理输入,否则默认max_len在前。

输出:
packed_batch:压缩过的词向量矩阵,它的类型是:torch.nn.utils.rnn.PackedSequence

OK,我们已经得到了packed_batch,然后将它放入我们的模型(RNN\LSTM\GRU…)

encoder_outputs_packed, (h_last, c_last) = self.lstm(packed_batch)

一般呢,我们进一步要用到的是最后一个时间步的隐藏层输出,也就是h_last,现在我们已经达成了目的:对于没有padding过的句子,我们得到的就是最后一个时间步的输出,对于padding过的句子,我们得到的是最后那个单词对应的输出。
但如果我们要使用encoder_outputs_packed这个输出,用作进一步的词向量表示,用此时的encoder_outputs_packed是不行的

需要用到第二个函数

encoder_output, _ = nn.utils.rnn.pad_packed_sequence(encoder_outputs_packed, batch_first=True)

在实验中我发现了一个问题,比如我设定max_len的长度是200,也就是经过padding后所有的句子统一长度为200。当前batch中,如果最长的句子实际长为135,那,经过packed,没什么问题,然后经过model,也没问题。
问题在pad上,pad后的所有句子长度最大就只有135了,也就是说所有<135的句子会重新pad,pad到135,而不是原来的max_len=200。
我在用ESIM模型时,需要输入两个句子,那么经过pad之后,这两个句子的长度可能就不同了,再后来计算attention的时候就发生了错误。看了看原码,原来函数还有一个参数叫做total_length,于是我把它设置成了200,就好了。

这样就完成了处理变长的全过程,了吗?–没有。因为我们之前已经把输入按照其长度排序了,那么标签就对应不上了,有一种做法是:

记得标签要对应

在排序的时候就要记住没有个句子原来的下标,存储在restoration_index里面,然后使用index_select这一tensor自带方法进行反排序:

encoder_output = encoder_output.index_select(0, restoration_index)

或者也可以在排序的时候,就把对应的label也排了,等等。

总结

我们可以使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来避免padding对句子表示的影响。同时我们还需要注意排序过后,样本标签不对应的问题。

参考资料:

pytorch中如何处理RNN输入变长序列padding

猜你喜欢

转载自blog.csdn.net/jokerxsy/article/details/107272133