PyTorch推出无CUDA加速推理,告别英伟达的时代

最近,PyTorch官方发布了关于无CUDA计算的实现方法,并对各个内核进行了微基准测试比较,探讨了未来如何进一步优化Triton内核,以缩小与CUDA的性能差距。

在训练、微调和推理大语言模型(LLM)时,使用英伟达的GPU和CUDA已经成为一种普遍做法。同样,在更广泛的机器学习领域,CUDA的依赖性也是显而易见,它为加速机器学习模型提供了显著的性能提升。

尽管CUDA在加速计算方面的主导地位依然稳固,且是英伟达的重要竞争壁垒,但一些新兴技术逐渐对其形成挑战。例如,OpenAI推出的Triton在可用性、内存开销和AI编译器堆栈构建等方面展现出一定的优势,并持续得到了开发。

最近,PyTorch宣布将探索“无英伟达CUDA参与的大模型推理”。在谈到为何要全面采用Triton时,PyTorch表示:“Triton为大模型在不同类型的GPU上运行提供了一条新的途径,包括英伟达、AMD、英特尔及其他基于GPU的加速器。”

此外,Triton还为GPU编程提供了更高的抽象层次,使得使用PyTorch编写高性能内核的速度超过了使用供应商特定API的效率。

图片

在PyTorch的博客中,讨论了如何使用流行的LLM模型(如Meta的Llama3-8B和IBM的Granite-8B Code)实现FP16推理的方法,其中所有计算都是通过OpenAI的Triton语言完成的。

在生成单个token的时间上,使用基于Triton内核的模型,PyTorch在英伟达H100 GPU上实现了Llama和Granite相较于CUDA内核主导工作流程的0.76-0.78倍性能,而在英伟达A100 GPU上则实现了0.62-0.82倍的性能。

图片

图 1. Llama3-8B和Granite-8B在英伟达H100和A100上的Triton与CUDA变体的推理吞吐量比较。设置:批大小 = 2,输入序列长度 = 512,输出序列长度 = 256

或许是时候真正告别英伟达了。

图片

Transformer块的构成

PyTorch团队首先对基于Transformer模型中的计算进行了细分,下面的图展示了典型Transformer块的“内核”。

图片

图 2

Llama3架构的核心操作总结如下:

  • 均方根归一化(RMSNorm)
  • 矩阵乘法:Fused QKV
  • RoPE
  • 注意力机制
  • 矩阵乘法:输出投影
  • RMSNorm
  • 矩阵乘法:Fused Gate + Up Projection
  • 激活函数:SiLU
  • 点乘(Element Wise Multiplication)
  • 矩阵乘法:Down Projection

这些操作均通过在GPU上执行一个(或多个)内核来计算。虽然不同Transformer模型中的每个内核的细节可能会有所不同,但核心操作仍然保持一致。例如,IBM的Granite 8B Code模型在MLP层中使用偏置,而Llama3则没有。这种差异确实需要对内核进行调整。通常,这些Transformer块是通过堆叠的方式构建而成,并通过嵌入层相互连接。

模型推理

典型的模型架构代码与PyTorch启动的python model.py文件共享。在默认的PyTorch Eager Execution模式下,这些内核都是通过CUDA执行的。为了实现Llama3-8B和Granite-8B的全程Triton推理,需要编写并集成手写的Triton内核,同时利用torch.compile生成Triton操作。首先,PyTorch用编译器生成的Triton内核替换较小的操作,接着,再用手写的Triton内核替换更复杂和昂贵的计算(如矩阵乘法和闪存注意力)。

Torch.compile能够自动为RMSNorm、RoPE、SiLU和点乘生成Triton内核。使用Nsight Systems等工具,可以观察到这些生成的内核在矩阵乘法和注意力之间显示为微小的深绿色内核。

图片

图 3. 使用torch.compile跟踪Llama3-8B,显示用于矩阵乘法和闪存注意力的CUDA内核。

在上述跟踪中,PyTorch团队注意到,在Llama3-8B模型中,占E2E延迟80%的主要操作是矩阵乘法和注意力内核,而这两者仍然是CUDA内核。因此,为了弥补剩余的性能差距,PyTorch团队用手写的Triton内核替换了矩阵乘法和注意力内核。

Triton SplitK GEMM内核

对于线性层中的矩阵乘法,PyTorch团队编写了一个自定义的FP16 Triton GEMM(通用矩阵-矩阵乘法)内核,利用了SplitK工作分解。

GEMM内核调优

为实现最佳性能,PyTorch团队使用穷举搜索方法来调整SplitK GEMM内核。Granite-8B和Llama3-8B的线性层具有如下形状:

图片

图 4. Granite-8B和Llama3-8B线性层权重矩阵的形状。

每个线性层都有不同的权重矩阵形状。为了获得最佳性能,必须针对每个形状轮廓调整Triton内核。在对每个线性层进行调优后,PyTorch能够实现Llama3-8B和Granite-8B相较于未调优的Triton内核1.20倍的E2E加速。

Flash Attention内核

PyTorch团队对现有的Triton闪存注意力内核进行了评估,使用了不同的配置,包括:

  • AMD Flash
  • OpenAI Flash
  • Dao AI Lab Flash
  • XFormers Flash
  • PyTorch FlexAttention

在Eager模式和编译模式下,PyTorch团队分别评估了每个内核的文本生成质量。下图5总结了不同Flash Attention内核的比较结果。

图片

上述图表总结了PyTorch观察到的开箱即用情况,并预计内核2到5在修改后能够满足上述标准。然而,这也表明,拥有一个可用于基准测试的内核往往仅仅是将其作为端到端生产内核的起点。

PyTorch团队选择在后续测试中使用AMD闪存注意力内核,该内核通过torch.compile进行编译,并在Eager和编译模式下产生清晰的输出。

为确保torch.compile与AMD闪存注意力内核的兼容性,PyTorch团队必须将其定义为PyTorch自定义算子。此外,封装更复杂的闪存注意力内核遵循以下两个步骤:

第一步是将函数封装为一个PyTorch自定义算子。

图片

第二步是向该算子添加一个FakeTensor内核,并在给定闪存输入张量的形状(q、k和v)时,计算闪存内核的输出形状。

图片

在将Triton闪存内核定义为自定义op后,PyTorch团队成功编译了该内核以实现端到端运行。

图片

图 6:在交换Triton矩阵乘法和Triton闪存注意力内核后,使用torch.compile的Llama3-8B跟踪。

从图中可以看出,在集成SplitK矩阵乘法内核后,torch op封装闪存注意力内核,然后运行torch.compile,即可实现100%Triton计算内核的前向传播。

端到端基准测试

PyTorch团队分别对在英伟达H100和A100(单GPU)上运行Granite-8B和Llama3-8B模型进行了端到端测试,使用了两种不同的配置来执行基准测试。

其中,Triton内核配置使用了:

  • Triton SplitK GEMM
  • AMD Triton Flash Attention

CUDA内核配置使用了:

  • cuBLAS GEMM
  • cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)

在典型推理设置下,Eager和torch编译模式的吞吐量和inter-token延迟如下图所示。

图片

图 7:H100和A100上Granite-8B和Llama3-8B单token生成延迟(批大小 = 2,输入序列长度 = 512,输出序列长度 = 256)。

总体来看,在H100上,Triton模型的性能最高可达CUDA模型的78%;在A100上可达82%。这些性能差距主要由矩阵乘法和闪存注意力内核的延迟造成。

微基准测试

下图8展示了在英伟达H100上运行Llama3-8B时,Triton和CUDA内核延迟的比较。输入为一个任意prompt(批大小 = 1,prompt序列长度 = 44),以解码延迟时间。

最终结果显示,Triton矩阵乘法内核相较于CUDA慢了1.2至1.4倍,而AMD Triton闪存注意力相较于CUDA SDPA慢了1.6倍。

这些结果突显了需要进一步提升GEMM和Flash Attention等核心原语内核的性能。最近的研究(如FlashAttention-3、FlexAttention)已提出了更好地利用底层硬件与Triton的方法,PyTorch希望能够在这些基础上实现更大的加速。为此,PyTorch团队正在验证FlexAttention的端到端性能。目前,FlexAttention的初步微基准测试结果显示,在查询向量较小的情况下,有望实现更长的上下文以及解码问题形状。

图片

图 9:在英伟达H100 SXM5 80GB上进行FlexAttention内核基准测试(批大小 = 1,最大头数 = 32,头维数 = 128)。

未来工作

展望未来,PyTorch团队计划继续探索优化矩阵乘法的途径,以更有效地利用硬件,并实现基于Triton的方法的更大加速。

在闪存注意力方面,PyTorch团队计划研究FlexAttention和FlashAttention-3等内核中使用的技术,以进一步缩小Triton与CUDA之间的性能差距。同时,他们也将探索端到端的FP8 LLM推理。