长亭百川云 - 文章详情

LLM推理加速:decode阶段的Attention在GPU上的优化

阿里技术

196

2024-07-27

这是2024年的第53篇文章

( 本文阅读时间:15分钟 )

随着大语言模型(Large Language Models,LLMs)在各领域的广泛应用,如何以低成本构建高吞吐、低延迟的推理服务成为了一个紧迫的问题。考虑到LLM在GPU上推理时参数量和计算量较大以致于单流执行就可以充分利用GPU资源,我们可以把LLM的推理延时分解到kernel level,因此,进一步的,不考虑时间占比小的kernel计算后,LLM的延时优化也就相应的分解成GEMM和Attention的kernel优化。

RTP-LLM是阿里巴巴智能引擎团队开发的大模型推理加速引擎,作为一个高性能的大模型推理解决方案,它已被广泛应用于阿里内部。在这篇文章里,我们将基于RTP-LLM的实践,介绍decode阶段的Attention在GPU上是如何优化的。

01

背景

我们比较熟悉的Attention计算如下图所示,包含Q与K相乘,其结果在mask后做softmax,然后与V相乘,得到Attention的结果。在LLM推理的decode阶段,由于加入KV Cache优化,一次迭代只需要计算新增的一个token,因此计算可以变化为当前step的Q(seq == 1)与K Cache、V Cache做计算。





计算过程中各tensor的shape可以表示为:

Q

(B, H, 1, D)

K Cache

(B, H_kv, S, D)

V Cache

(B, H_kv, S, D)

Q * K Cache

(B, H, 1, S)

O

(B, H, 1, D)

参数的解释如下表:

B

Batch size / num_seqs

H

head_num

H_kv

head_num_kv

S

seq length of KV Cache

D

head_size

在本文的分析中,我们考虑简单的Multi Head Attention实现,即H == H_kv。

我们希望以一个kernel实现上图的计算。出于性能考虑,将前一步的BiasAdd,Rotary Embedding也一起融合。因此这个kernel接受的输入是经过QKV GEMM的Q、K、V,在kernel中完成BiasAdd,然后Q和K会一起做Rotary Embedding。当前的K和V会分别与之前计算得到的KV Cache做拼接,扩展成(B, H, S, D)的KV Cache。然后Q与K Cache相乘,得到的结果在S维计算SoftMax,再与V Cache相乘,得到最后的输出。

简化的代码示例如下:

#(B, 3, H, D) -> 3 * (B, H, 1, D)
Q, K, V = add(QKV_buffer, bias)
#(B, H, 1, D) -> (B, H, 1, D)
Q, K = rotary_embedding(Q, K)
#(B, H, 1, D) -> (B, H, S, D)
K, V = concat(past_KV, K, V)
#(B, H, 1, D) * (B, H, S, D) -> (B, H, 1, S)
res = matmul(Q, K)/ sqrt(self.head_dim)
#(B, H, 1, S) -> (B, H, 1, S)
res = =softmax(res, dim=-1)
#(B, H, 1, S) * (B, H, S, D) -> (B, H, 1, D)
out = matmul(res, V)

在整个计算过程中,BiasAdd、Rotary Embedding相对计算量较小,对kernel的latency影响较小,因此下文省略这一部分的分析。

02

计算分析

我们以当前的TensorRT-LLM中Masked Multi Head Attention(MMHA)的实现为例,分析当前的MMHA是怎么实现高性能。

涉及到GPU并行计算,我们首先需要考虑的是任务划分。对于这个场景,任务划分实际上是清晰的:B和H是并行维度,在执行过程中的Q*K和QK*V,都可以理解成一个batch size = B * H的Batch GEMV。而SoftMax又是一个Reduce操作,因此单个GEMV的计算最好尽量在一个block内完成。因此,MMHA比较基础的任务划分大概是:

dim3 grid(B, H, 1);
dim3 block(THREAD_PER_BLOCK, 1, 1);

这里的THREAD_PER_BLOCK是指每个block用多少threads来完成一个head在S上的计算。通常更多的threads会更提高每个SM的active warps以更好的利用计算资源,增加load指令以提高数据load效率,因此我们希望THREAD_PER_BLOCK越大越好(最好接近1024)。但由于kernel整体计算逻辑较为复杂,寄存器用量较大,threads可能会收到寄存器总量的限制;且在寄存器总量的限制下,我们可以简单的认为每个SM上只有一个active block。

基于这种划分,我们继续考虑每个block是如何计算。传入kernel的QKV buffer实际的layout是(B,3, H,D),在TensorRT-LLM的实现中,会先load当前step的Q和K并计算BiasAdd和ROPE,并将这一步得到的K Cache写回global buffer。完成这些计算后,因为数据还在寄存器中,会直接计算对应的QK dot。由于这些计算的耗时较短,我们略过这一部分分析,直接看看TensorRT-LLM是怎么计算Q * K Cache的。

Q乘K Cache的计算在D上累加。假设我们用half存KV Cache,用float做乘累加,为了保证load效率,每个thread会load连续的16bytes数据,也就是8个elements。对于常见的D==128来说,需要16个threads完成一个head的计算。可以认为给block中的threads进行了分组,每组16个threads负责一个head的计算,其中每个threads读8个elements,并完成这8个elements对应的乘累加,然后这组threads间通过warp内的shuffle完成当前head的计算,并将计算结果存到smem中。组和组在S上展开。





接下来计算SoftMax,由于前面的计算保证了SoftMax需要的输入都在当前block内的smem中,通过Block Reduce Max和Block Reduce Sum就可以完成SoftMax的计算。

乘V Cache的计算思路与上文乘K Cache非常类似,略有不同的是这一步计算需要在S上累加。依然将threads分组,每组16个threads负责一个head, 每个thread负责8个elements的计算。由于需要在S上累加,因此每个thread需要保存当前所计GPUsde算的8个elements的部分累加和。最后借助smem,将不同threads上的部分和累加,得到Attention的输出。





在计算过程中,qk dot除了hfma计算外,也可以调用hmma来完成单个head的计算。但由于kernel的性能瓶颈在访存上,dot用哪种计算方式对性能的影响不大;我们的测试也验证了这个结论。

上文的分析中依然省略了一些细节。具体的,比如我们现在通常用paged KV Block Array来存储KV Cache,也就是KV Cache可以在S维度上不连续,以便在S不断增长时动态的分配buffer。但paged的存储并不改变D维的连续,因此也不影响上文的分析。此外,每个thread在load KV Cache时会多load一部分存进本地的寄存器,以尽可能的将load数据与dot计算overlap。

主流框架如vllm,xformers等对MMHA的实现和优化思路都是比较类似的,仅在细节处略有差异。TensorRT-LLM在mmha外还实现了XQA以继续优化decode阶段Attention的计算,但由于代码未开源,本文也不做分析。

03

改进与优化

当然上文分析到的简单优化在实际应用中还是不那么够用的,最常见的就是小B和长S场景。

考虑到实际的GPU资源,如A100有108个SM,且每个SM上只有一个block(也就是只计算一个head),当B * H恰好占满108(或108的整数倍)个SM时,可以认为占用率是比较高的。以7B模型,或者72B模型2TP举例,H = 32,当B = 3时,占用率是88.9%;而当B = 4时,就会因必须打两轮而带来占用率的下降到59%;当B = 1时,占用率就会低到30%了。这个时候如果S比较大,我们就会发现,大部分的device资源还空闲着,也不得不一起等待部分SM完成一个时间很长的计算。

针对这种情况,我们把S也分配到grid dim上,资源分配也就改为:

dim3 grid(B, H, S_tile);
dim3 block(THREAD_PER_BLOCK, 1, 1);

在这种任务划分下,结合上文分析,假设长seq每个SM上仅有一个active block,则waves可以计算为:





当waves越接近ceil值,意味着device occupancy会越高。在小B大S的场景下,如果在S切分,也就是S_tile > 1,有利于增加occupancy。在这种情况下,S_tile个block共同完成一个head在S上的计算,每个block负责S / S_tile的计算,block间的reduce通过开辟额外的global buffer来完成。这种模式下,新增的global读写会带来有额外的耗时,但因为增加了device occupancy,因此在小B大S的场景下有明显的性能提升。这也就是flash decoding的思路,且在各框架均有支持。

除了性能的考虑外,超长seq也必须走进这种实现。由于Q * K的结果需要在S上做reduce,也就是smem需要存下对应大小的中间数据,根据kernel实现,输入类型是half,以float累加,可以估计算为6 * S。而根据A100每个SM实际可用smem是163KB计算,最大可支持的S在27K左右。当输入大于这个值时,我们必须在seq做切分,以保证kernel的计算。

另一种需要做不同的任务划分的场景是GQA。在GQA的计算下,每个head的KV Cache会对应于多个head的Q,为了避免KV Cache的重复load,资源分配应该改为,并基于此做计算上的调整。

dim3 grid(B, H_kv, S_tile);
dim3 block(THREAD_PER_BLOCK, 1, 1);

除了优化任务划分,MMHA的优化还可以在以下方面继续展开:

1)优化寄存器用量可能达到更高的占用率(可以在一个SM上launch多个block或者增大每个block的threads);

2)继续调整KV Cache的load行为,让计算和数据读取进一步overlap以缓解memory bound的场景;

3)在大B加上GQA,Attention会走到compute bound,需要调整计算模式以更好的利用tensor core加速计算等等。

我们将持续探索和实践,以更灵活、更具拓展性的优化策略来面对日益多样化和复杂的应用场景。优化后的kernel会开源在RTP-LLM中,欢迎大家交流共建。

参考链接

[01] TensorRT-LLM

https://github.com/NVIDIA/TensorRT-LLM

[02] vllm

https://github.com/vllm-project/vllm

[03] xformers

https://github.com/facebookresearch/xformers

[04] flash decoding

https://crfm.stanford.edu/2023/10/12/flashdecoding.html

[03] RTP-LLM

https://github.com/alibaba/rtp-llm

欢迎留言一起参与讨论~

相关推荐
关注或联系我们
添加百川云公众号,移动管理云安全产品
咨询热线:
4000-327-707
百川公众号
百川公众号
百川云客服
百川云客服

Copyright ©2024 北京长亭科技有限公司
icon
京ICP备 2024055124号-2