ChatGLM Efficient Tuning: data_collator.py Source Code Walkthrough

 src/data_collator.py

import torch

from typing import Optional
from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

from .other import IGNORE_INDEX


class DataCollatorForChatGLM(DataCollatorWithPadding):
    r"""
    Data collator for ChatGLM. It is capable of dynamically padding for batched data.
    """
    def __init__(
            self,
            tokenizer: PreTrainedTokenizer,
            model: PreTrainedModel,
            ignore_pad_token_for_loss: Optional[bool] = False,
            use_v2: Optional[bool] = False
    ):
        super().__init__(tokenizer, padding=True)
        self.model = model
        self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
        if use_v2:
            self.get_attention_masks = self.get_attention_masks_v2
            self.get_position_ids = self.get_position_ids_v2
        else:
            self.get_attention_masks = self.get_attention_masks_v1
            self.get_position_ids = self.get_position_ids_v1

    def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.

        Note that ChatGLM assigns False to tokens that should be attended to in the attention mask; in most other settings it would be True.

        According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L680
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()

        for i, seq in enumerate(input_ids):
            attention_mask[i, :, :(seq == self.tokenizer.bos_token_id).nonzero()[0].item()] = 1 # context
            attention_mask[i, :, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding

        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()
        return attention_mask

    def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates position ids for left-padded sequences.

        According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L692
        """
        batch_size, seq_length = input_ids.size()
        mask: int = self.model.config.mask_token_id
        gmask: int = self.model.config.gmask_token_id
        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
        block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)

        for i, seq in enumerate(input_ids):
            mask_token = gmask if gmask in seq else mask
            context_length = (seq == self.tokenizer.bos_token_id).nonzero()[0].item()
            padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
            position_ids[i, padding_length:] = torch.arange(
                seq_length - padding_length,
                dtype=torch.long,
                device=device
            )
            if self.model.position_encoding_2d or (mask_token != gmask): # 2d position encoding or not gMASK
                position_ids[i, context_length:] = (seq == mask_token).nonzero()[0].item() - padding_length # mask position
            block_position_ids[i, context_length:] = torch.arange(
                seq_length - context_length,
                dtype=torch.long,
                device=device
            ) + 1

        if self.model.position_encoding_2d:
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)

        return position_ids
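
To make the index arithmetic concrete, here is a minimal hand-traced sketch of what get_attention_masks_v1 and get_position_ids_v1 compute for a single left-padded sequence. The token ids are arbitrary placeholders chosen only for illustration (they are not real ChatGLM-6B vocabulary ids); the trace simply repeats the slicing the two methods perform.

import torch

# Placeholder ids, for illustration only; real ChatGLM-6B ids differ.
PAD, GMASK, BOS = 0, 8, 9

# One left-padded sequence: [PAD, PAD, prompt tokens, gMASK, BOS, response tokens]
seq = torch.tensor([PAD, PAD, 11, 12, GMASK, BOS, 21, 22])

padding_length = (seq != PAD).nonzero()[0].item()   # 2: index of the first non-pad token
context_length = (seq == BOS).nonzero()[0].item()   # 5: BOS marks where generation begins

# v1 attention mask: start from a causal (lower-triangular) mask, let every position
# attend to the whole context, and let nothing attend to the padding.
mask = torch.ones(len(seq), len(seq)).tril_()
mask[:, :context_length] = 1    # context tokens are attended bidirectionally
mask[:, :padding_length] = 0    # padding tokens are never attended
mask = (mask < 0.5)             # ChatGLM v1 convention: True means "masked out"

# v1 position ids (2D encoding): the first channel restarts at 0 after the padding and
# is frozen at the gMASK position for the response; the second channel counts the response.
position_ids = torch.zeros(len(seq), dtype=torch.long)
position_ids[padding_length:] = torch.arange(len(seq) - padding_length)
position_ids[context_length:] = (seq == GMASK).nonzero()[0].item() - padding_length
block_position_ids = torch.zeros(len(seq), dtype=torch.long)
block_position_ids[context_length:] = torch.arange(len(seq) - context_length) + 1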

The code above defines a class named DataCollatorForChatGLM, which collates data for the ChatGLM conversational model. Let us walk through it step by step:

  1. from transformers import DataCollatorWithPadding, BatchEncoding imports DataCollatorWithPadding and BatchEncoding from HuggingFace's transformers library. DataCollatorWithPadding is a collator that pads sequences of different lengths to a common length, while BatchEncoding is the data structure used to represent model inputs.

  2. from transformers.modeling_utils import PreTrainedModel imports PreTrainedModel, the base class of all pretrained models in transformers.

  3. from transformers.tokenization_utils import PreTrainedTokenizer imports PreTrainedTokenizer, the base class of all pretrained tokenizers in transformers.

  4. from .other import IGNORE_INDEX imports IGNORE_INDEX from the other module in the same package; it is a special index value used to exclude certain tokens from the loss computation.

  5. class DataCollatorForChatGLM(DataCollatorWithPadding): defines a new class DataCollatorForChatGLM that inherits from DataCollatorWithPadding.

  6. def __init__(self, tokenizer: PreTrainedTokenizer, model: PreTrainedModel, ignore_pad_token_for_loss: Optional[bool] = False, use_v2: Optional[bool] = False): defines the initializer of DataCollatorForChatGLM. It takes a tokenizer, a model, an optional boolean ignore_pad_token_for_loss, and another optional boolean use_v2.

  7. super().__init__(tokenizer, padding=True) calls the initializer of the parent class DataCollatorWithPadding, passing it the tokenizer and the padding flag.

  8. self.model = model stores the given model on self.model.

  9. self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id assigns IGNORE_INDEX to self.label_pad_token_id when ignore_pad_token_for_loss is True, and the tokenizer's pad_token_id otherwise (a short loss sketch follows this list).

  10. if use_v2: ... else: ... binds either the v1 or the v2 attention-mask and position-id generators, depending on the value of use_v2.

  11. def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: defines the v1 attention-mask generator, which takes an input_ids tensor and a device and returns an attention-mask tensor.

  12. def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: defines the v1 position-id generator, which takes an input_ids tensor and a device and returns a position-id tensor.
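
As a brief illustration of item 9, the sketch below shows why labels set to IGNORE_INDEX drop out of the loss. It assumes IGNORE_INDEX equals -100, which matches PyTorch's default ignore_index for CrossEntropyLoss; the tensors are arbitrary toy data.

import torch
import torch.nn as nn

IGNORE_INDEX = -100   # assumed value, matching PyTorch's default ignore_index

logits = torch.randn(1, 4, 32)                                # (batch, seq_len, vocab)
labels = torch.tensor([[IGNORE_INDEX, IGNORE_INDEX, 5, 7]])   # padded positions masked out

loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss = loss_fn(logits.view(-1, 32), labels.view(-1))
# Only the last two positions contribute to the loss; the padded labels are skipped.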

In short, this code defines a data collator that serves the ChatGLM conversational model. Given a batch of inputs, it pads them dynamically and generates the corresponding attention masks and position ids. In practice, a DataCollatorForChatGLM instance is passed to a data loader, which uses it to build the model's input batches.
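
The listing above omits the collator's __call__ method, which in the original repository assembles the final BatchEncoding from these helpers, but the typical wiring is straightforward. The sketch below is hypothetical: it assumes a tokenizer, a model, and a tokenized train_dataset whose items contain input_ids and labels already exist.

from torch.utils.data import DataLoader

data_collator = DataCollatorForChatGLM(
    tokenizer=tokenizer,
    model=model,
    ignore_pad_token_for_loss=True,
    use_v2=False,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,   # each batch is padded dynamically by the collator
)

for batch in train_loader:
    outputs = model(**batch)    # hypothetical forward pass on one collated batch
    break

The remaining two methods, shown below, are the v2 counterparts that __init__ binds when use_v2 is True: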

    def get_attention_masks_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length), device=device)

        for i, seq in enumerate(input_ids):
            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding

        return attention_mask

    def get_position_ids_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates position ids for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)

        for i, seq in enumerate(input_ids):
            padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
            position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long, device=device)

        return position_ids
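
Compared with the v1 variants, the v2 methods are much simpler: the attention mask is a flat (batch, seq_length) padding mask and the position ids are a plain 0..n-1 ramp starting after the padding. Here is a minimal hand-traced sketch for one left-padded sequence, again using placeholder ids chosen only for illustration.

import torch

PAD = 0   # placeholder pad id, for illustration only
seq = torch.tensor([PAD, PAD, 11, 12, 13, 21, 22, 23])

padding_length = (seq != PAD).nonzero()[0].item()   # 2: index of the first non-pad token

# v2 attention mask: 0 on padding, 1 everywhere else.
attention_mask = torch.ones(len(seq))
attention_mask[:padding_length] = 0

# v2 position ids: consecutive positions starting at 0 after the padding.
position_ids = torch.zeros(len(seq), dtype=torch.long)
position_ids[padding_length:] = torch.arange(len(seq) - padding_length)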


Reposted from blog.csdn.net/sinat_37574187/article/details/131483002