src/data_collator.py
import torch
from typing import Optional

from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

from .other import IGNORE_INDEX

class DataCollatorForChatGLM(DataCollatorWithPadding):
    r"""
    Data collator for ChatGLM. It is capable of dynamically padding batched data.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        model: PreTrainedModel,
        ignore_pad_token_for_loss: Optional[bool] = False,
        use_v2: Optional[bool] = False
    ):
        super().__init__(tokenizer, padding=True)
        self.model = model
        # Labels use IGNORE_INDEX so that padding positions are skipped by the loss.
        self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
        # Select the mask / position-id builders for ChatGLM (v1) or ChatGLM2 (v2).
        if use_v2:
            self.get_attention_masks = self.get_attention_masks_v2
            self.get_position_ids = self.get_position_ids_v2
        else:
            self.get_attention_masks = self.get_attention_masks_v1
            self.get_position_ids = self.get_position_ids_v1

    def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.

        Note that ChatGLM assigns False to tokens that should be attended to in the
        attention mask, whereas in general settings it would be True.

        According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L680
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()  # causal mask: each token attends to itself and the past
        for i, seq in enumerate(input_ids):
            attention_mask[i, :, :(seq == self.tokenizer.bos_token_id).nonzero()[0].item()] = 1 # context is fully visible (bidirectional)
            attention_mask[i, :, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # left padding is never attended
        attention_mask.unsqueeze_(1)  # add the head dimension
        attention_mask = (attention_mask < 0.5).bool()  # invert: True marks masked-out positions
        return attention_mask

    def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates position ids for left-padded sequences.

        According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L692
        """
        batch_size, seq_length = input_ids.size()
        mask: int = self.model.config.mask_token_id
        gmask: int = self.model.config.gmask_token_id
        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
        block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
        for i, seq in enumerate(input_ids):
            mask_token = gmask if gmask in seq else mask
            context_length = (seq == self.tokenizer.bos_token_id).nonzero()[0].item()
            padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
            position_ids[i, padding_length:] = torch.arange(
                seq_length - padding_length,
                dtype=torch.long,
                device=device
            )
            if self.model.position_encoding_2d or (mask_token != gmask): # 2d position encoding or not gMASK
                position_ids[i, context_length:] = (seq == mask_token).nonzero()[0].item() - padding_length # mask position
            block_position_ids[i, context_length:] = torch.arange(
                seq_length - context_length,
                dtype=torch.long,
                device=device
            ) + 1
        if self.model.position_encoding_2d:
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        return position_ids

This code defines a class named `DataCollatorForChatGLM`, which collates data for the ChatGLM conversational model. Let's go through it line by line:

- `from transformers import DataCollatorWithPadding, BatchEncoding` imports `DataCollatorWithPadding` and `BatchEncoding` from Hugging Face's `transformers` library. `DataCollatorWithPadding` is a data collator that handles sequences of different lengths by padding them to the same length, while `BatchEncoding` is the data structure used to represent model inputs.
- `from transformers.modeling_utils import PreTrainedModel` imports `PreTrainedModel` from the `transformers` library; it is the base class of all pretrained models.
- `from transformers.tokenization_utils import PreTrainedTokenizer` imports `PreTrainedTokenizer` from the `transformers` library; it is the base class of all pretrained tokenizers.
- `from .other import IGNORE_INDEX` imports `IGNORE_INDEX` from the `other` module in the same package; it is a special index value used to ignore certain tokens when computing the loss.
- `class DataCollatorForChatGLM(DataCollatorWithPadding):` defines a new class `DataCollatorForChatGLM` that inherits from `DataCollatorWithPadding`.
- `def __init__(self, tokenizer: PreTrainedTokenizer, model: PreTrainedModel, ignore_pad_token_for_loss: Optional[bool] = False, use_v2: Optional[bool] = False):` defines the initializer of `DataCollatorForChatGLM`, which takes a tokenizer, a model, an optional boolean `ignore_pad_token_for_loss`, and another optional boolean `use_v2`.
- `super().__init__(tokenizer, padding=True)` calls the initializer of the parent class `DataCollatorWithPadding`, passing it the tokenizer and the padding flag.
- `self.model = model` stores the given model on `self.model`.
- `self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id` assigns `IGNORE_INDEX` to `self.label_pad_token_id` if `ignore_pad_token_for_loss` is `True`, and the tokenizer's `pad_token_id` otherwise.
- `if use_v2: ... else: ...` chooses, based on the value of `use_v2`, whether the v1 or the v2 versions of the attention-mask and position-id generation functions are used.
- `def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:` defines the v1 attention-mask generation function, which takes an `input_ids` tensor and a device object and returns an attention-mask tensor.
- `def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:` defines the v1 position-id generation function, which takes an `input_ids` tensor and a device object and returns a position-id tensor. A toy walk-through of both v1 functions follows this list.
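To make the v1 logic concrete, here is a minimal sketch of one loop iteration on a toy left-padded sequence. All token ids (pad = 0, bos = 3, gMASK = 4) are invented for illustration and do not match the real ChatGLM vocabulary; the operations mirror the loop bodies above, assuming 2-D position encoding with a gMASK token present.

import torch

pad_token_id, bos_token_id, gmask_token_id = 0, 3, 4  # made-up ids for the demo
seq = torch.tensor([0, 0, 5, 4, 3, 7, 8])  # [pad, pad, ctx, gMASK, bos, ans, ans]
seq_length = seq.size(0)

context_length = (seq == bos_token_id).nonzero()[0].item()  # 4: index of bos
padding_length = (seq != pad_token_id).nonzero()[0].item()  # 2: first real token

# Attention mask: causal triangle, fully visible context, blocked padding.
mask = torch.ones((seq_length, seq_length))
mask.tril_()
mask[:, :context_length] = 1   # context attends bidirectionally
mask[:, :padding_length] = 0   # padding is never attended
mask = (mask < 0.5).bool()     # ChatGLM convention: True marks masked-out positions

# 2-D position ids: the first row repeats the gMASK position for generated
# tokens; the second row counts 1, 2, ... inside the generated block.
position_ids = torch.zeros(seq_length, dtype=torch.long)
position_ids[padding_length:] = torch.arange(seq_length - padding_length)
position_ids[context_length:] = (seq == gmask_token_id).nonzero()[0].item() - padding_length
block_position_ids = torch.zeros(seq_length, dtype=torch.long)
block_position_ids[context_length:] = torch.arange(seq_length - context_length) + 1

print(mask.int())
print(position_ids)        # tensor([0, 0, 0, 1, 1, 1, 1])
print(block_position_ids)  # tensor([0, 0, 0, 0, 1, 2, 3])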
Overall, this code defines a data collator that serves the ChatGLM conversational model: given a batch of inputs, it pads them dynamically and generates the attention masks and position ids. In practice, a `DataCollatorForChatGLM` object is handed to the data loader, which uses it to build the model's input batches, as sketched below.
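Here is a minimal usage sketch. The names `tokenizer`, `model`, and `tokenized_dataset` are placeholders assumed to be prepared elsewhere, and the collator is assumed to be callable on a list of tokenized features, as Hugging Face data collators are:

from torch.utils.data import DataLoader

# Hypothetical setup: `tokenizer`, `model`, and `tokenized_dataset` are
# assumed to be loaded elsewhere; they are not defined in this file.
data_collator = DataCollatorForChatGLM(
    tokenizer=tokenizer,
    model=model,
    ignore_pad_token_for_loss=True,  # labels at padding positions become IGNORE_INDEX
    use_v2=False                     # v1 scheme: 3-D attention mask, 2-D position ids
)

dataloader = DataLoader(
    tokenized_dataset,        # each item: a dict of token ids from the tokenizer
    batch_size=4,
    collate_fn=data_collator  # pads each batch and builds masks / position ids
)

The v2 functions below are noticeably simpler: the v2 model consumes an ordinary 2-D padding mask and plain 1-D position ids, so neither block position ids nor bidirectional-context handling is needed.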
    def get_attention_masks_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length), device=device)
        for i, seq in enumerate(input_ids):
            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # zero out left padding
        return attention_mask

    def get_position_ids_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates position ids for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
        for i, seq in enumerate(input_ids):
            padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
            position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long, device=device) # positions restart at the first real token
        return position_ids
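By contrast, the v2 logic amounts to ordinary left-padding bookkeeping. A small self-contained check (the pad id 0 is again invented for illustration; the mask line is a vectorized equivalent of the loop in get_attention_masks_v2):

import torch

pad_token_id = 0
seq = torch.tensor([0, 0, 11, 12, 13])  # [pad, pad, tok, tok, tok]

padding_length = (seq != pad_token_id).nonzero()[0].item()  # 2: first real token
position_ids = torch.zeros_like(seq)
position_ids[padding_length:] = torch.arange(seq.numel() - padding_length)
attention_mask = (seq != pad_token_id).long()  # 1 = attend, 0 = padding

print(position_ids)    # tensor([0, 0, 0, 1, 2])
print(attention_mask)  # tensor([0, 0, 1, 1, 1])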