LLMs之FlashAttention-2:《FlashAttention-2: Faster Attention with Better Parallelism and Work Partition

LLMs之FlashAttention-2:《FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning更快的注意力与更好的并行性和工作分区》翻译与解读

导读:FlashAttention-2通过算法并行计算工作分配的优化,实现了原FlashAttention注意力计算的显著加速,有助于推动更长序列模型和应用的发展。

长文本序列模型的 Attention 计算开销是序列长度的二次方,限制了模型输入序列长度的扩展。原来的FlashAttention算法已经实现了2-4倍速度提升,但效率仍有提升空间,无法达到矩阵乘法性能极限。新的FlashAttention-2算法从根本上重写,利用Nvidia CUTLASS库实现,利用并行计算更全面,工作分配更优化。

>> FlashAttention-2减少非矩阵乘法运算,提升矩阵乘法利用率,GPU资源利用率更高

>>FlashAttention-2在并行度方面增加了序列长度维度的并行计算

>>FlashAttention-2在线程块内部线程群的工作分配上进行了优化,减少线程同步带来的缓存读取开销

>>FlashAttention-2支持头数量最大256,支持多查询注意力机制(如MQA/GQA),应用范围更广。

>>通过Benchmark证实,FlashAttention-2在A100上实现2倍速度提升,最高能达到335TFLOPs/s效率。

>>在训练GPT风格模型上,FlashAttention-2可以实现1.3倍训练速度提升。

未来工作将继续在新硬件和数据类型上优化FlashAttention-2算法。

扫描二维码关注公众号,回复: 16840468 查看本文章

目录

《FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning》翻译与解读

扩展Transformer的上下文长度是一个挑战—需要更长上下文的语言模型:GPT-4(32k)、MPT(65k)、Claude(100k)

2022年发布FlashAttention(速注意力并减少其内存占用):比几线快2~4倍

今天,正式发布FlashAttention-2

1、FlashAttention回顾

FlashAttention的优点:是一种重新排序注意力计算的算法,利经典技术(tiling、重新计算,提速2~4倍)加速+从与序列长度成二次关系降低到与线性关系

FlashAttention的缺点:低占用率(GPU工作分区不够优化)、不必要的共享内存读写

Diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

2、FlashAttention-2:更好的算法、并行性和工作分区

(1)、Fewer non-matmul FLOPs减少非矩阵乘法FLOPs:基于使矩阵乘法更快的现代GPU+尽可能多地使用矩阵乘法FLOP

(2)、Better Parallelism更好的并行性:FlashAttention基于额外添加了序列长度维度的并行→加速为长序列场景

(3)、Better Work Partitioning更好的工作分区:

对比:FlashAttention(原版本将K和V划分给不同线程群需要线程同步写出中间结果)、FlashAttention-2(新版本将Q分片给不同线程群计算后直接与共享的K、V相乘得到输出)

3、新特性(头维度高达256+支持多查询注意力):兼容更多模型+同时支持多查询注意力【如MQA/GQA】→进一步减小键值缓存大小+提速

4、Attention Benchmark—Attention的基准:FlashAttention-2可达2倍FlashAttention、9倍PyTorch标准实现、训练GPT风格模型时的1.3倍实现

Attention forward + backward speed on A100 GPU:高达225 TFLOPs/s

Attention forward + backward speed on H100 GPU:多达335 TFLOPs/s

Baseline: Megatron-LM without FlashAttention. Megatron-LM now has an option to use FlashAttention.

5、Discussion and Future Work讨论与未来工作:FlashAttention-2使得同样成本训练更长文本模型【8k→16k】→未来将其应用于更多设备和数据类型

6、Acknowledgement致谢


FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning翻译与解读

地址

博客文章FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning | Princeton NLP Group

时间

2023年7月17日

作者

谷歌学者Tri Dao,坦福大学计算机科学博士,Together.AI的首席科学家

扩展Transformer的上下文长度是一个挑战—需要更长上下文的语言模型:GPT-4(32k)、MPT(65k)、Claude(100k)

Just within the last year, there have been several language models with much longer context than before: GPT-4 with context length 32k, MosaicML’s MPT with context length 65k, and Anthropic’s Claude with context length 100k. Emerging use cases such as long document querying and story writing have demonstrated a need for models with such long context. Scaling up the context length of Transformers is a challenge, since the attention layer at their heart has runtime and memory requirements that are quadratic in the input sequence length.

仅在过去的一年里,出现了几种比以前上下文更长的语言模型:GPT-4的上下文长度为32k, MosaicML的MPT的上下文长度为65k, Anthropic的Claude的上下文长度为100k。新兴的用例,如长文档查询和故事创作,已经证明了需要具有如此长上下文的模型扩展Transformer的上下文长度是一个挑战,因为它们核心的注意力层具有与输入序列长度成二次关系的运行时和内存要求

2022年发布FlashAttention(速注意力并减少其内存占用):比几线快2~4倍

A year ago, we released FlashAttention, a new algorithm to speed up attention and reduce its memory footprint—without any approximation. We’ve been very happy to see FlashAttention being adopted by many organizations and research labs to speed up their training & inference (see this page for a partial list). Even though FlashAttention was already 2-4x faster than optimized baselines at the time of its release, it still has quite a bit of headroom. FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s (e.g. up to 124 TFLOPs/s on A100 GPU).

一年前,我们发布了FlashAttention,这是一种新的算法,可以加速注意力并减少其内存占用,而无需任何近似。我们很高兴看到FlashAttention被许多组织和研究实验室采用,以加快他们的训练和推理(请查看此页面以获取部分列表)。尽管FlashAttention在发布时已经比优化后的基线快了2-4倍,但它仍然有相当大的提升空间。FlashAttention仍然不如优化的矩阵乘法(GEMM)操作快,仅达到理论最大FLOPs/s的25-40%(例如在A100 GPU上高达124 TFLOPs/s)。

今天,正式发布FlashAttention-2

In the past few months, we’ve been working on the next version, FlashAttention-2, that makes FlashAttention even better. Rewritten completely from scratch to use the primitives from Nvidia’s  CUTLASS 3.x and its core library  CuTe, FlashAttention-2 is about 2x faster than its previous version, reaching up to 230 TFLOPs/s on A100 GPUs. When used end-to-end to train GPT-style language models, we reach a training speed of up to 225 TFLOPs/s (72% model FLOP utilization). In this blogpost, we describe some of the bottlenecks of FlashAttention, and how we use better parallelism and work partitioning to get significant speedup.

FlashAttention-2 is available at:  flash-attention

在过去的几个月里,我们一直在研发下一个版本,即FlashAttention-2,使FlashAttention变得更加出色。FlashAttention-2完全重写,使用Nvidia的CUTLASS 3.x和其核心库CuTe的原语,比其前一个版本快大约2倍,在A100 GPU上最高可达230 TFLOPs/s。当用于端到端训练类似GPT的语言模型时,我们达到了高达225 TFLOPs/s(模型FLOP利用率为72%)的训练速度。在本文中,我们将描述FlashAttention的一些瓶颈,以及我们如何使用更好的并行性和工作分区来获得显著的加速。FlashAttention-2可在:flash-attention

1、FlashAttention回顾

FlashAttention的优点:是一种重新排序注意力计算的算法,利经典技术(tiling、重新计算,提速2~4倍)加速+从与序列长度成二次关系降低到与线性关系

FlashAttention is an algorithm that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. Tiling means that we load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.

FlashAttention是一种重新排序注意力计算的算法,并利用经典技术(tiling、重新计算)来显著加速它并将内存使用量从与序列长度成二次关系降低到与线性成正比的算法。tiling意味着我们从HBM(GPU内存)加载输入块到SRAM(快速缓存),针对该块执行注意力计算,并在HBM中更新输出。通过不将大型中间关注矩阵写入HBM,我们减少了内存读写的数量,从而带来2-4倍的时间加速。

FlashAttention的缺点:低占用率(GPU工作分区不够优化)、不必要的共享内存读写

However, FlashAttention still has some inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.

然而,FlashAttention仍然存在一些低效性,这是因为在GPU上不同线程块和线程束之间的工作分区不够优化,导致低占用率或不必要的共享内存读写。

Diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

FlashAttention向前传递图:使用tiling和softmax重新缩放,我们按块操作,避免必须从HBM读/写,同时获得正确的输出,没有近似。

Diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

2、FlashAttention-2更好的算法、并行性和工作分区

(1)、Fewer non-matmul FLOPs减少非矩阵乘法FLOPs基于使矩阵乘法更快的现代GPU+尽可能多地使用矩阵乘法FLOP

We tweak the algorithm from FlashAttention to reduce the number of non-matmul FLOPs. This is important because modern GPUs have specialized compute units (e.g., Tensor Cores on Nvidia GPUs) that makes matmul much faster. As an example, the A100 GPU has a max theoretical throughput of 312 TFLOPs/s of FP16/BF16 matmul, but only 19.5 TFLOPs/s of non-matmul FP32. Another way to think about this is that each non-matmul FLOP is 16x more expensive than a matmul FLOP. To maintain high throughput, we want to spend as much time on matmul FLOPs as possible. We rewrite the online softmax trick used in FlashAttention to reduce the number of rescaling ops, as well as bound-checking and causal masking operations, without changing the output.

我们对FlashAttention算法进行了微调,以减少非矩阵乘法FLOPs的数量。这很重要,因为现代GPU拥有专门的计算单元(例如Nvidia GPU上的Tensor Cores),使矩阵乘法计算速度更快。例如,以A100 GPU为例,FP16/BF16的最大理论吞吐量为312 TFLOPs/s,而非矩阵FP32的最大理论吞吐量为19.5 TFLOPs/s。

另一种思考这个问题的方式是,每个非矩阵FLOP的成本是矩阵乘法FLOP的16倍。为了保持高吞吐量,我们希望尽可能多地使用矩阵乘法FLOP的时间。我们重新编写了FlashAttention中使用的在线softmax技巧,以减少重新缩放操作的数量,以及边界检查和因果掩码操作,而不改变输出。

(2)、Better Parallelism更好的并行性FlashAttention基于额外添加了序列长度维度的并行→加速为长序列场景

第一版FlashAttention的并行计算主要基于批量大小和头数量,但对于长序列而言(即小批量或小头数量情况),会导致GPU多处理器利用率低。FlashAttention-2在此基础上额外添加了序列长度维度的并行,更好地利用GPU多处理器,从而为长序列场景带来明显的速度提升。

The first version of FlashAttention parallelizes over batch size and number of heads. We use 1 thread block to process one attention head, and there are overall (batch_size * number of heads) thread blocks. Each thread block is scheduled to run on a streaming multiprocessor (SM), and there are 108 of these SMs on an A100 GPU for example. This scheduling is efficient when this number is large (say >= 80), since we can effectively use almost all of the compute resources on the GPU.

In the case of long sequences (which usually means small batch sizes or small number of heads), to make better use of the multiprocessors on the GPU, we now additionally parallelize over the sequence length dimension. This results in significant speedup for this regime.

FlashAttention的第一个版本在批量大小和头数上进行了并行化。我们使用1个线程块来处理一个关注头,总共有(批量大小*头数)个线程块每个线程块被安排在一个流多处理器(SM)上运行,例如A100 GPU上有108个这样的SM。当这个数字很大时(例如>= 80),这种调度是有效的,因为我们可以有效地使用GPU上的几乎所有计算资源。

对于长序列的情况(通常意味着较小的批量大小或较少的头数),为了更好地利用GPU上的多处理器,我们现在还在序列长度维度上进行了额外的并行化。这将为这种情况带来显著的加速。

(3)、Better Work Partitioning更好的工作分区:

FlashAttention-2改进线程块内线程群之间的工作分配方案,原版本将K和V划分给不同线程群需要线程同步写出中间结果,而新版本将Q分片给不同线程群计算后直接与共享的K、V相乘得到输出,消除了线程同步带来的内存读取开销,从而提升速度。

Even within each thread block, we also have to decide how to partition the work between different warps (a group of 32 threads working together). We typically use 4 or 8 warps per thread block, and the partitioning scheme is described below. We improve this partitioning in FlashAttention-2 to reduce the amount of synchronization and communication between different warps, resulting in less shared memory reads/writes.

即使在每个线程块中,我们也必须决定如何在不同的warp(一组32个线程一起工作)之间划分工作。

即使在每个线程块内部,我们还必须决定如何在不同的线程束和线程块之间分配工作。通常情况下,我们每个线程块使用4或8个线程束,并且下面描述了分区方案。在FlashAttention-2中,我们改进了这种分区,以减少不同线程束之间的同步和通信量,从而减少了共享内存的读写。

对比:FlashAttention(原版本将K和V划分给不同线程群需要线程同步写出中间结果)、FlashAttention-2(新版本将Q分片给不同线程群计算后直接与共享的K、V相乘得到输出)

For each block, FlashAttention splits K and V across 4 warps while keeping Q accessible by all warps. This is referred to as the “sliced-K” scheme. However, this is inefficient since all warps need to write their intermediate results out to shared memory, synchronize, then add up the intermediate results. These shared memory reads/writes slow down the forward pass in FlashAttention.

In FlashAttention-2, we instead split Q across 4 warps while keeping K and V accessible by all warps. After each warp performs matrix multiply to get a slice of Q K^T, they just need to multiply with the shared slice of V to get their corresponding slice of the output. There is no need for communication between warps. The reduction in shared memory reads/writes yields speedup.

对于每个块,FlashAttention将K和V分配给4个线程束,同时让Q对所有线程束可访问。这被称为“切片-K”方案。然而,这是低效的,因为所有线程束都需要将其中间结果写入共享内存,进行同步,然后将中间结果相加。这些共享内存的读写减慢了FlashAttention的向前传递

在FlashAttention-2中,我们改为将Q分配给4个线程束,同时让K和V对所有线程束可访问。每个线程束执行矩阵乘法以获取Q K^T的一个切片,然后它们只需将其与共享切片的V相乘,即可获得相应的输出切片。不需要在线程束之间进行通信。减少共享内存的读写带来了加速。

3、新特性(头维度高达256+支持多查询注意力):兼容更多模型+同时支持多查询注意力【如MQA/GQA】→进一步减小键值缓存大小+提速

New features: head dimensions up to 256, multi-query attention

FlashAttention-2支持的头数量上限提高到256,兼容更多模型,同时支持多查询注意力和分组查询注意力,这些变体可以进一步减小键值缓存大小,明显提升推理吞吐量。

FlashAttention only supported head dimensions up to 128, which works for most models but a few were left out. FlashAttention-2 now supports head dimension up to 256, which means that models such as GPT-J, CodeGen and CodeGen2, and StableDiffusion 1.x can use FlashAttention-2 to get speedup and memory saving.

This new version also supports multi-query attention (MQA) as well as grouped-query attention (GQA). These are variants of attention where multiple heads of query attend to the same head of key and value, in order to reduce the size of KV cache during inference and can lead to significantly higher inference throughput.

FlashAttention仅支持头维度高达128,这适用于大多数模型,但有些模型不能使用。FlashAttention-2现在支持头维度高达256,这意味着模型如GPT-JCodeGenCodeGen2以及StableDiffusion 1.x可以使用FlashAttention-2来获得加速和节省内存。

这个新版本还支持多查询注意力(MQA)以及分组查询注意力(GQA)。这些是注意力的变种,其中多个查询头注意力相同的键头和值头,以减小推断过程中的KV缓存大小,可以显著提高推断吞吐量。

4、Attention BenchmarkAttention基准FlashAttention-2可达2倍FlashAttention9倍PyTorch标准实现训练GPT风格模型时1.3倍实现

通过Benchmark测试,FlashAttention-2在不同设置下的注意力前向和反向传播速度可达2倍FlashAttention和9倍PyTorch标准实现,在A100GPU上最大运行速度达335TFLOPs/s,在end-to-end训练GPT模型时实现1.3倍速度提升。

We measure the runtime of different attention methods on an A100 80GB SXM4 GPU for different settings (without / with causal mask, head dimension 64 or 128). We see that FlashAttention-2 is around 2x faster than FlashAttention (as well as its other implementations in the xformers library and in Triton). Compared to a standard attention implementation in PyTorch, FlashAttention-2 can be up to 9x faster.

我们在A100 80GB SXM4 GPU上对不同设置(无/有因果掩码,头维度64或128)上的不同关注方法的运行时进行了测量。我们发现FlashAttention-2比FlashAttention快大约2倍(以及xformers库和Triton中的其他实现)。与PyTorch中的标准关注实现相比,FlashAttention-2的速度可以提高多达9倍。

Attention forward + backward speed on A100 GPU:高达225 TFLOPs/s

Attention forward + backward speed on A100 GPU

Just running the same implementation on H100 GPUs (using no special instructions to make use of new hardware features such as TMA and 4th-gen Tensor Cores), we obtain up to 335 TFLOPs/s.

Attention forward + backward speed on H100 GPU

只要在H100 GPU上运行相同的实现(不使用特殊指令来利用新的硬件特性,如TMA和第四代张量核心),我们可以获得多达335 TFLOPs/s的吞吐量。

Attention forward + backward speed on H100 GPU:多达335 TFLOPs/s

When used to train a GPT-style model end-to-end, FlashAttention-2 helps achieve up to 225 TFLOPs/s on A100 GPU (72% model FLOPs utilization). This is a 1.3x end-to-end speedup over an already very optimized model with FlashAttention.

当用于端到端训练GPT风格的模型时,FlashAttention-2可以在A100 GPU上实现高达225 TFLOPs/s的吞吐量(模型FLOP利用率为72%)。这相当于在已经经过优化的模型中使用FlashAttention获得了1.3倍的端到端加速。

Baseline: Megatron-LM without FlashAttention. Megatron-LM now has an option to use FlashAttention.

*Baseline: Megatron-LM without FlashAttention. Megatron-LM now has an option to use FlashAttention. We plan to integrate FlashAttention-2 to Megatron-LM in the near future.

*基线:Megatron-LM没有使用FlashAttention。Megatron-LM现在可以选择使用FlashAttention。我们计划在不久的将来将FlashAttention-2集成到Megatron-LM中。

5、Discussion and Future Work讨论与未来工作FlashAttention-2使得成本训练更长文本模型【8k→16k】→未来将其应用于更多设备和数据类型

FlashAttention-2运行速度提升2倍,可以用相同成本训练更长文本模型,未来工作计划将其应用于更多设备和数据类型,同时通过算法和低级优化相结合可能支持远超以往的长序列训练。

FlashAttention-2 is 2x faster than FlashAttention, which means that e.g. we can train models with 16k longer context for the same price as previously training a 8k context model. We’re excited about how this can be used to understand long books and reports, high resolution images, audio and video. FlashAttention-2 will also speed up training, finetuning, and inference of existing models.

In the near future, we plan to collaborate with folks to make FlashAttention widely applicable in different kinds of devices (e.g. H100 GPUs, AMD GPUs), as well as new data types such as fp8. As an immediate next step, we plan to optimize FlashAttention-2 for H100 GPUs to use new hardware features (TMA, 4th-gen Tensor Cores, fp8). Combining the low-level optimizations in FlashAttention-2 with high-level algorithmic changes (e.g. local, dilated, block-sparse attention) could allow us to train AI models with much longer context. We’re also excited to work with compiler researchers to make these optimization techniques easily programmable.

FlashAttention-2比FlashAttention快2倍,这意味着,例如,我们可以用与以前训练8k上下文模型相同的价格来训练具有16k更长上下文的模型。我们很高兴看到它可以用来理解长篇书籍和报告、高分辨率图像、音频和视频。FlashAttention-2还将加速现有模型的训练、微调和推理。

在不久的将来,我们计划与其他人合作,使FlashAttention在不同类型的设备(例如H100 GPU、AMD GPU)上广泛适用,以及新的数据类型,如fp8。下一步,,我们计划优化FlashAttention-2以适应H100 GPU,以利用新的硬件特性(TMA、第四代张量内核、fp8)。将FlashAttention-2中的低级优化与高级算法更改(例如局部、扩张、块稀疏关注)结合起来,可以使我们能够训练具有更长上下文的AI模型。我们也很高兴与编译器研究人员合作,使这些优化技术易于编程。

6、Acknowledgement致谢

We thank Phil Tillet and Daniel Haziza, who have implemented versions of FlashAttention in Triton and the xformers library. FlashAttention-2 was motivated by exchange of ideas between different ways that attention could be implemented. We are grateful to the Nvidia CUTLASS team (especially Vijay Thakkar, Haicheng Wu, and Andrew Kerr) for their CUTLASS library, in particular the CUTLASS 3.x release, which provides clean abstractions and powerful building blocks for the implementation of FlashAttention-2. We thank Driss Guessous for integrating FlashAttention to PyTorch. FlashAttention-2 has benefited from helpful discussions with Phil Wang, Markus Rabe, James Bradbury, Young-Jun Ko, Julien Launay, Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani, and Erich Elsen. Thanks to Stanford CRFM and Stanford NLP for the compute support. We thank Dan Fu and Christopher Ré for their collaboration, constructive feedback, and constant encouragement on this line of work of designing hardware-efficient algorithms. We thank Albert Gu and Beidi Chen for their helpful suggestions on early drafts of the FlashAttention-2 technical report.

我们要感谢Phil Tillet和Daniel Haziza,他们在Tritonxformers库中实现了FlashAttention的版本。FlashAttention-2的灵感来自于不同注意力执行方式之间的思想交流。我们非常感谢Nvidia CUTLASS团队(特别是Vijay Thakkar, Haicheng Wu和Andrew Kerr)的CUTLASS库,特别是CUTLASS 3。它为FlashAttention-2的实现提供了清晰的抽象和强大的构建块。我们感谢Driss Guessous将FlashAttention集成到PyTorch。flashatten2得益于与Phil Wang, Markus Rabe, James Bradbury, Young-Jun Ko, Julien Launay, Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani和Erich Elsen的有益讨论。感谢斯坦福CRFM和斯坦福NLP提供的计算支持。我们感谢Dan Fu和Christopher r<s:1>在设计硬件高效算法方面的合作、建设性的反馈和不断的鼓励。我们感谢Albert Gu和Beidi Chen在FlashAttention-2技术报告的早期草稿中提出的有益建议。

猜你喜欢

转载自blog.csdn.net/qq_41185868/article/details/133108384