nlp中常用DataLoader中的collate_fn,对batch进行整理使其符合bert的输入

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
此处的collate_fn,是一个函数,会将DataLoader生成的batch进行一次预处理

假设我们有一个Dataset,有input_ids、attention_mask等列:
使用torch创建dataloder时,如果使用默认的collate_fn(default_collate),输出的batch中,input_ids,和token_type_ids,attention_mask都是长度为 sequence_length 的列表(如果input_ids都已经pad到sequence_length),
列表的每个元素是大小为[batch_size]的tensor,
如果input_ids的长度不相等还会报错。

这并不是我们想要的模型的输入格式。
我们希望一个batch中,input_ids的应该是shape=(batch_size,max_seq_length)的tensor

默认情况下:

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
for batch in train_loader:
    print(batch)
    break
#如下所示,batch_size=3,所以input_ids都是长度为3的tensor组成的列表,这不是我们想要的
{
    
    'input_ids': List[torch.Tensor], 'attention_mask': List[torch.Tensor]}
Tensor的shape为(batch_size)(3)

这样是无法直接把这个batch输入bert的,必须把input_ids,和token_type_ids,attention_mask都转化为大小为 [ batch_size,sequence_length ] 的tensor才能输入bert。
所以,需要定义自己的collate_fn函数,对batch进行整理(使用torch.stack函数对列表进行拼接):

#定义collate_fn,把input_ids,和token_type_ids,attention_mask列转化为tensor
def collate_fn(examples):
    batch = default_collate(examples)
    batch['input_ids'] = torch.stack(batch['input_ids'], dim=1)
    batch['token_type_ids'] = torch.stack(batch['token_type_ids'], dim=1)
    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=1)
    return batch

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
#输出第一个batch
for batch in train_loader:
    print(batch)
    break

结果如下,可以直接输入bert模型

output=model(**batch)
batch
{
    
    'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
tensor的shape=(batch_size,max_seq_length),符合bert的输入格式

然而,以上方法必须要求所有input_ids长度一致,否则依然会报错。
其实tokenizer类已经提供了更简单的方法 tokenizer.pad()

    def collate_fn_SentenceClassify(features,tokenizer):
        #batch = default_collate(examples) 不再需要这句话
        #将batch中的input_ids和attention_mask等补齐
        batch = tokenizer.pad(
            features,
            padding=True,
            max_length=None,
        )
        return batch

此处,features是List[Dict[str,Tensor]], batch被整理成了Dict[str,Tensor],符合输入格式。
和DataCollatorWithPadding 效果一样

from transformers import DataCollatorWithPadding 

对于更复杂的任务,例如文本生成任务,可以直接调用transformers库中的collate_fn,例如

from transformers import DataCollatorForSeq2Seq

具体用法可以看huggingface官网
更多DataCollator

猜你喜欢

转载自blog.csdn.net/qq_51750957/article/details/128220076