Baichuan2源码解析之:Baichuan2-13B-Chat/modelling_baichuan.py

# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.

from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers

import os
from contextlib import contextmanager
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


def _gen_alibi_mask(tensor, n_head, max_pos):
    slopes = torch.Tensor(_get_interleave(n_head))
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        self.o_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
        if xops is not None and self.training:
            attn_weights = None
            # query_states = query_states.transpose(1, 2)
            # key_states = key_states.transpose(1, 2)
            # value_states = value_states.transpose(1, 2)
            # attn_output = xops.memory_efficient_attention(
            #     query_states, key_states, value_states, attn_bias=attention_mask
            # )
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
            attn_output = attn_output.transpose(1, 2)
        else:
            attn_weights = torch.matmul(
                query_states, key_states.transpose(2, 3)
            ) / math.sqrt(self.head_dim)

            if attention_mask is not None:
                if q_len == 1:  # inference with cache
                    if len(attention_mask.size()) == 4:
                        attention_mask = attention_mask[:, :, -1:, :]
                    else:
                        attention_mask = attention_mask[:, -1:, :]
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
                )

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
  1. from .configuration_baichuan import BaichuanConfig

    导入当前包下的 configuration_baichuan 模块中的 BaichuanConfig 类。

  2. from .generation_utils import build_chat_input, TextIterStreamer

    导入当前包下的 generation_utils 模块中的 build_chat_inputTextIterStreamer

  3. import math

    导入Python的内置数学函数库。

  4. from threading import Thread

    导入Python的多线程库中的 Thread 类。

  5. from typing import List, Optional, Tuple, Union

    导入Python的类型注释库,这里导入了 ListOptionalTupleUnion

  6. import torch

    导入PyTorch框架。

  7. from torch import nn

    从PyTorch中导入神经网络库。

  8. from torch.nn import CrossEntropyLoss

    从PyTorch的神经网络库中导入交叉熵损失函数。

  9. from torch.nn import functional as F

    从PyTorch的神经网络库中导入功能模块,并为其取别名F

  10. from transformers import PreTrainedModel, PretrainedConfig

transformers库中导入 PreTrainedModelPretrainedConfig

  1. from transformers.activations import ACT2FN

    transformers库中导入激活函数的映射表 ACT2FN

  2. from transformers.generation.utils import GenerationConfig

    transformers库中的generation.utils模块导入GenerationConfig

  3. from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

    transformers库中的modeling_outputs模块导入BaseModelOutputWithPastCausalLMOutputWithPast

  4. from transformers.utils import logging, ContextManagers

    transformers库中的utils模块导入loggingContextManagers

  5. import os

    导入Python的内置OS模块,用于处理操作系统相关任务。

  6. from contextlib import contextmanager

    从Python的contextlib模块导入contextmanager,它用于创建上下文管理器。

  7. from accelerate import init_empty_weights

    accelerate库中导入init_empty_weights函数。

  8. logger = logging.get_logger(__name__)

    使用transformers提供的日志功能创建一个日志器对象,__name__是当前模块的名称。

19-26. try ... except ImportError 代码块:

 
 

goCopy code

试图从 `xformers` 库中导入 `ops` 模块,并为其取别名`xops`。如果导入失败(即没有正确安装`xformers`库),则将`xops`设为`None`,并发出一个警告消息。

28-39. _get_interleave(n):

 
 

goCopy code

定义了一个辅助函数`_get_interleave`。这个函数有一个内嵌函数`_get_interleave_power_of_2`,用于计算并返回一个列表,该列表的长度与给定的数字`n`相同,元素为从开始值开始的等比数列。主函数`_get_interleave`根据`n`是否是2的整数次幂来调用内部函数,并返回一个列表。

41-43. _fill_with_neg_inf(t):

 
 

goCopy code

定义了一个辅助函数`_fill_with_neg_inf`,它接受一个张量`t`,将其填充为负无穷大,并返回该张量。

45-50. _buffered_future_mask(tensor, maxpos, alibi, attn_heads):

 
 

goCopy code

定义了一个辅助函数`_buffered_future_mask`,该函数用于生成一个未来掩码(常用于Transformer的自注意力机制)。

52-64. _gen_alibi_mask(tensor, n_head, max_pos):

 
 

goCopy code

定义了一个辅助函数`_gen_alibi_mask`,用于生成一个"alibi"掩码。

66-77. RMSNorm类:

 
 

Copy code

定义了一个层归一化的变体:RMSNorm。这是一个神经网络模块,其核心功能是通过平方的均值进行归一化。

79-93. MLP类:

 
 

goCopy code

定义了一个多层感知机(MLP)类,这是一个神经网络模块,包含三个线性层和一个激活函数。在其前向传播中,输入`x`首先经过`gate_proj`层和激活函数,然后与`up_proj`的输出相乘,最后经过`down_proj`层。

这段代码主要包含了一些辅助函数和两个神经网络模块:RMSNormMLP。这些功能可能用于更大的Transformer模型或其他神经网络模型中。

Baichuan_attention:
 

这是一个名为 BaichuanAttention 的类,它定义了一个自注意力机制的模块,与BERT、Transformer等模型中的注意力机制相似。以下是代码的逐行解释:

  1. class BaichuanAttention(torch.nn.Module): 定义一个名为 BaichuanAttention 的类,它继承了 torch.nn.Module,意味着这是一个PyTorch神经网络模块。

  2. def __init__(self, config: BaichuanConfig): 定义构造函数,它接收一个 BaichuanConfig 类型的参数。

  3. super().__init__() 调用父类的构造函数,这是在PyTorch中定义自己的网络层时的常规操作。

  4. self.config = config 将传入的配置保存为类的属性。

  5. self.hidden_size = config.hidden_size 从配置中获取 hidden_size 属性,并保存为类的属性。

  6. self.num_heads = config.num_attention_heads 从配置中获取注意力头数,并保存为类的属性。

  7. self.head_dim = self.hidden_size // self.num_heads 计算每个注意力头的维度,并保存为类的属性。

  8. self.max_position_embeddings = config.model_max_length 从配置中获取模型的最大长度,并保存为类的属性。

9-12. if (self.head_dim * self.num_heads) != self.hidden_size: 验证隐藏层大小是否可以被注意力头数整除。

13-15. self.W_pack = torch.nn.Linear(...) 定义一个线性层,用于对输入的隐藏状态进行线性变换以获取查询、键和值。

16-18. self.o_proj = torch.nn.Linear(...) 定义一个线性层,用于注意力机制后的输出。

19-22. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 定义一个辅助函数,将给定的张量重新整形以适应注意力机制。

23-32. def forward(...): 定义模块的前向传播函数。

  1. bsz, q_len, _ = hidden_states.size() 获取输入张量的批次大小、序列长度和隐藏层大小。

34-38. 这部分代码将输入隐藏状态通过线性变换得到查询、键和值。

39-48. 这部分代码调整查询、键和值的形状,使其适合注意力计算。

49-54. 如果提供了 past_key_value,则与当前的键和值连接。

55-56. 如果 use_cacheTrue,则保存键和值。

57-68. 判断是否安装了 xformers,并根据条件使用不同的注意力计算方式。

69-81. 如果没有使用 xformers,则使用常规的缩放点积自注意力计算方式。

82-83. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 调整注意力输出的形状。

  1. attn_output = self.o_proj(attn_output) 将注意力输出通过线性层。

85-87. 根据 output_attentions 的值,可能会返回注意力权重或设置其为 None

  1. return attn_output, attn_weights, past_key_value 返回注意力输出、注意力权重和过去的键值对。

这个模块实现了缩放点积自注意力,它是Transformer架构中的关键组件

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/133090157