# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer
import math
from threading import Thread
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers
import os
from contextlib import contextmanager
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
try:
from xformers import ops as xops
except ImportError:
xops = None
logger.warning(
"Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
)
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
new_future_mask = _future_mask.to(tensor)
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def _gen_alibi_mask(tensor, n_head, max_pos):
slopes = torch.Tensor(_get_interleave(n_head))
position_point = torch.arange(max_pos) - max_pos + 1
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, epsilon=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.epsilon = epsilon
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
# convert into half-precision
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class MLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class BaichuanAttention(torch.nn.Module):
def __init__(self, config: BaichuanConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.model_max_length
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
)
self.W_pack = torch.nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=False
)
self.o_proj = torch.nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = (
proj.unflatten(-1, (3, self.hidden_size))
.unsqueeze(0)
.transpose(0, -2)
.squeeze(-2)
)
query_states = (
proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
key_states = (
proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
value_states = (
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if xops is not None and self.training:
attn_weights = None
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)
# attn_output = xops.memory_efficient_attention(
# query_states, key_states, value_states, attn_bias=attention_mask
# )
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
attn_output = attn_output.transpose(1, 2)
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None:
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
-
from .configuration_baichuan import BaichuanConfig
导入当前包下的
configuration_baichuan
模块中的BaichuanConfig
类。 -
from .generation_utils import build_chat_input, TextIterStreamer
导入当前包下的
generation_utils
模块中的build_chat_input
和TextIterStreamer
。 -
import math
导入Python的内置数学函数库。
-
from threading import Thread
导入Python的多线程库中的
Thread
类。 -
from typing import List, Optional, Tuple, Union
导入Python的类型注释库,这里导入了
List
、Optional
、Tuple
和Union
。 -
import torch
导入PyTorch框架。
-
from torch import nn
从PyTorch中导入神经网络库。
-
from torch.nn import CrossEntropyLoss
从PyTorch的神经网络库中导入交叉熵损失函数。
-
from torch.nn import functional as F
从PyTorch的神经网络库中导入功能模块,并为其取别名
F
。 -
from transformers import PreTrainedModel, PretrainedConfig
从transformers
库中导入 PreTrainedModel
和 PretrainedConfig
。
-
from transformers.activations import ACT2FN
从
transformers
库中导入激活函数的映射表ACT2FN
。 -
from transformers.generation.utils import GenerationConfig
从
transformers
库中的generation.utils
模块导入GenerationConfig
。 -
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
从
transformers
库中的modeling_outputs
模块导入BaseModelOutputWithPast
和CausalLMOutputWithPast
。 -
from transformers.utils import logging, ContextManagers
从
transformers
库中的utils
模块导入logging
和ContextManagers
。 -
import os
导入Python的内置OS模块,用于处理操作系统相关任务。
-
from contextlib import contextmanager
从Python的
contextlib
模块导入contextmanager
,它用于创建上下文管理器。 -
from accelerate import init_empty_weights
从
accelerate
库中导入init_empty_weights
函数。 -
logger = logging.get_logger(__name__)
使用
transformers
提供的日志功能创建一个日志器对象,__name__
是当前模块的名称。
19-26. try
... except ImportError
代码块:
goCopy code
试图从 `xformers` 库中导入 `ops` 模块,并为其取别名`xops`。如果导入失败(即没有正确安装`xformers`库),则将`xops`设为`None`,并发出一个警告消息。
28-39. _get_interleave(n)
:
goCopy code
定义了一个辅助函数`_get_interleave`。这个函数有一个内嵌函数`_get_interleave_power_of_2`,用于计算并返回一个列表,该列表的长度与给定的数字`n`相同,元素为从开始值开始的等比数列。主函数`_get_interleave`根据`n`是否是2的整数次幂来调用内部函数,并返回一个列表。
41-43. _fill_with_neg_inf(t)
:
goCopy code
定义了一个辅助函数`_fill_with_neg_inf`,它接受一个张量`t`,将其填充为负无穷大,并返回该张量。
45-50. _buffered_future_mask(tensor, maxpos, alibi, attn_heads)
:
goCopy code
定义了一个辅助函数`_buffered_future_mask`,该函数用于生成一个未来掩码(常用于Transformer的自注意力机制)。
52-64. _gen_alibi_mask(tensor, n_head, max_pos)
:
goCopy code
定义了一个辅助函数`_gen_alibi_mask`,用于生成一个"alibi"掩码。
66-77. RMSNorm
类:
Copy code
定义了一个层归一化的变体:RMSNorm。这是一个神经网络模块,其核心功能是通过平方的均值进行归一化。
79-93. MLP
类:
goCopy code
定义了一个多层感知机(MLP)类,这是一个神经网络模块,包含三个线性层和一个激活函数。在其前向传播中,输入`x`首先经过`gate_proj`层和激活函数,然后与`up_proj`的输出相乘,最后经过`down_proj`层。
这段代码主要包含了一些辅助函数和两个神经网络模块:RMSNorm
和MLP
。这些功能可能用于更大的Transformer模型或其他神经网络模型中。
Baichuan_attention:
这是一个名为 BaichuanAttention
的类,它定义了一个自注意力机制的模块,与BERT、Transformer等模型中的注意力机制相似。以下是代码的逐行解释:
-
class BaichuanAttention(torch.nn.Module):
定义一个名为BaichuanAttention
的类,它继承了torch.nn.Module
,意味着这是一个PyTorch神经网络模块。 -
def __init__(self, config: BaichuanConfig):
定义构造函数,它接收一个BaichuanConfig
类型的参数。 -
super().__init__()
调用父类的构造函数,这是在PyTorch中定义自己的网络层时的常规操作。 -
self.config = config
将传入的配置保存为类的属性。 -
self.hidden_size = config.hidden_size
从配置中获取hidden_size
属性,并保存为类的属性。 -
self.num_heads = config.num_attention_heads
从配置中获取注意力头数,并保存为类的属性。 -
self.head_dim = self.hidden_size // self.num_heads
计算每个注意力头的维度,并保存为类的属性。 -
self.max_position_embeddings = config.model_max_length
从配置中获取模型的最大长度,并保存为类的属性。
9-12. if (self.head_dim * self.num_heads) != self.hidden_size:
验证隐藏层大小是否可以被注意力头数整除。
13-15. self.W_pack = torch.nn.Linear(...)
定义一个线性层,用于对输入的隐藏状态进行线性变换以获取查询、键和值。
16-18. self.o_proj = torch.nn.Linear(...)
定义一个线性层,用于注意力机制后的输出。
19-22. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
定义一个辅助函数,将给定的张量重新整形以适应注意力机制。
23-32. def forward(...):
定义模块的前向传播函数。
bsz, q_len, _ = hidden_states.size()
获取输入张量的批次大小、序列长度和隐藏层大小。
34-38. 这部分代码将输入隐藏状态通过线性变换得到查询、键和值。
39-48. 这部分代码调整查询、键和值的形状,使其适合注意力计算。
49-54. 如果提供了 past_key_value
,则与当前的键和值连接。
55-56. 如果 use_cache
为 True
,则保存键和值。
57-68. 判断是否安装了 xformers
,并根据条件使用不同的注意力计算方式。
69-81. 如果没有使用 xformers
,则使用常规的缩放点积自注意力计算方式。
82-83. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
调整注意力输出的形状。
attn_output = self.o_proj(attn_output)
将注意力输出通过线性层。
85-87. 根据 output_attentions
的值,可能会返回注意力权重或设置其为 None
。
return attn_output, attn_weights, past_key_value
返回注意力输出、注意力权重和过去的键值对。
这个模块实现了缩放点积自注意力,它是Transformer架构中的关键组件。