如何扩展你的模型:中文译本

基于已批准的 12 章中文 Markdown 合并生成。公式由 MathJax 渲染,代码、链接、引用与章节结构保留在单一 HTML 文档中。

第 1 章

Roofline 全解析

《How To Scale Your Model》第 1 部分(第 0 部分:Introduction | 第 2 部分:TPUs

当我们在硬件上运行算法时,会同时受到三方面约束:计算机做数学运算的速度(OPs/second)、搬运数据的带宽(bytes/second),以及存储数据所需的总内存(bytes)。这些“roofline”约束让我们能够为某个给定计算的耗时给出上界和下界。

目录

时间到底花在哪?

先从一个极其简单的问题开始:为什么一个算法耗时是 50ms,而不是 50s 或 5ms?模型内部究竟发生了什么,才会花去这么多时间?我们又应该预期它会花多长时间?

计算:深度学习模型本质上可以看作是一堆矩阵乘法,而每个矩阵乘法又由浮点乘法和加法这样的“操作”(FLOPs)组成。加速器的速度决定了这些计算要花多久:

\[ \begin{equation} T_\text{math} = \frac{\text{Computation FLOPs}}{\text{Accelerator FLOPs/s}} \end{equation} \]

例如,NVIDIA H100 大约能提供 9.89e14 bfloat16(bf16bfloat16 的简写,一种机器学习里常见的 16 位浮点格式)FLOPs/s,而 TPU v6e 大约能提供 9.1e14 FLOPs/s。H100 和 B200 通常只能达到标称峰值 FLOPs 的 80% 到 85%,而 TPU 在正常使用中通常能更接近 95%。

这意味着,在 H100 上执行 1e12 FLOPs 大约需要 1e12 / 9.89e14 = 1.01ms,而在 TPU v6e 上大约需要 1e12 / 9.1e14 = 1.1ms。注意,这些芯片的定价并不相同,这里的比较没有按成本归一化。

芯片内通信:在单个加速器内部,张量需要在加速器内存(HBM)与计算核心之间传输。你会看到这条链路的带宽被称为 “HBM bandwidth”(NVIDIA 也称之为 “memory bandwidth”)。在 H100 上,这个数大约是 3.35TB/s,而在 TPU v6e 上,这个数大约是 1.6TB/s

芯片间通信:当我们把模型分布到多个加速器上时,张量经常需要在它们之间传输。具体到硬件上,往往有几种选择(ICI、DCN、PCIe),它们各自的带宽不同。

无论通信发生在芯片内还是芯片间,我们都用 bytes/s 来衡量,并用下式估计总通信时间:

\[ \begin{equation} T_\text{comms} = \frac{\text{Communication Bytes}}{\text{Network/Memory Bandwidth Bytes/s}} \end{equation} \]

通常情况下(但不总是这样),单芯片内部的计算可以与芯片内、芯片间通信重叠。因此,我们可以用计算时间与通信时间中的较大者来给训练和推理时间做下界,也可以用两者之和给出上界。实践中,我们通常针对两者的最大值进行优化,因为这样代数更简单,而且通过让通信与计算重叠,我们通常可以逼近这个边界。

如果按最大值来优化,那么上下界之间最多只差 2 倍,因为

$T_\text{math} + T_\text{comms} \leq 2 * \max(T_\text{math}, T_\text{comms})$。

在此基础上,若要进一步提高精度,就需要对“重叠区域”和各种额外开销建模,而这通常可以通过对你的具体模型和目标系统做 profiling 来获得依据。

\[ \begin{equation} T_\text{lower}=\max(T_\text{math}, T_\text{comms}) \end{equation} \] \[ \begin{equation} T_\text{upper} = T_\text{math} + T_\text{comms} \end{equation} \]

如果我们假设通信与计算可以完美重叠,那么当 $T_\text{math} > T_\text{comms}$ 时,我们就能让硬件满负荷利用。我们称这种情况为“计算受限”(compute-bound)。当 $T_\text{comms} > T_\text{math}$ 时,我们往往是“通信受限”(communication-bound),此时至少会有一部分加速器 FLOPs/s 被浪费在等待数据传输上。判断一个操作究竟是计算受限还是通信受限的一种方式,是看它的“算术强度”(arithmetic intensity,也叫 operational intensity)。

定义:一个算法的算术强度,是它执行的总 FLOPs 与它需要通信的字节数之比,无论这些通信发生在芯片内还是芯片间。

\[ \begin{equation} \text{Arithmetic Intensity} = \frac{\text{Computation FLOPs}}{\text{Communication Bytes}} \end{equation} \]

算术强度衡量的是某个操作“每字节对应多少 FLOPs”。一阶近似地看,当算术强度较高时,$T_\text{math}$ 相比 $T_\text{comms}$ 会更大,我们通常就能用满大部分可用 FLOPs。反之,如果不是这样,我们就会把更多时间花在通信上,并浪费 FLOPs。发生这种切换的点,就是硬件的“峰值算术强度”,也就是峰值加速器 FLOPs/s 与加速器带宽之比。

\[ \begin{align*} T_\text{math} > T_\text{comms} \Leftrightarrow \frac{\text{Computation FLOPs}} {\text{Accelerator FLOPs/s}} > \frac{\text{Communication Bytes}}{\text{Bandwidth Bytes/s}} & \\ [0.5em] \Leftrightarrow \frac{\text{Computation FLOPs}}{\text{Communication Bytes}} > \frac{\text{Accelerator FLOPs/s}}{\text{Bandwidth Bytes/s}} & \\ [0.5em] \Leftrightarrow \text{Intensity}(\text{Computation}) > \text{Intensity}(\text{Accelerator}) & \end{align*} \]

其中 $\text{Intensity}(\text{Accelerator})$ 指的就是加速器达到峰值 FLOPs/s 时所对应的算术强度。以 TPU v5e 的 MXU 为例,这个值大约是 240 FLOPs/byte,因为 TPU 能执行 1.97e14 FLOPs/s,并且能从 HBM 以 8.2e11 bytes/s 的速度加载数据。这里特意写明是 MXU,因为 TPU 还有其他执行单元,例如负责逐元素操作的 VPU,它们的峰值 FLOPs/s 是不同的。

这意味着,如果某个算法的算术强度低于 240 FLOPs/byte,那么它就会受限于字节加载,因此无法很好地利用硬件。这里还有一个前提:该算法的权重是从 HBM 加载,并且运行在 MXU 上。正如下一部分会讨论的,有时我们可以把参数存到带宽更高的 VMEM 中。很多算法也运行在 VPU 上,而不是 MXU,它们的性能特征不同。下面看这样一个例子:

例子(点积):为了用 bfloat16 精度计算两个向量的点积,x • y: bf16[N], bf16[N] → bf16[1],我们需要从内存中加载 $x$ 和 $y$,每个向量各占 $2 * N = 2N$ 字节;执行 $N$ 次乘法和 $N - 1$ 次加法;最后把 2 个字节写回 HBM。

\[ \begin{equation} \text{Intensity}(\text{dot product}) = \frac{\text{Total FLOPs}}{\text{Total Bytes}} = \frac{N + N - 1}{2N + 2N + 2} = \frac{2N - 1}{4N + 2} \rightarrow \frac{1}{2} \end{equation} \]

当 $N\rightarrow\infty$ 时,上式趋于 $\frac{1}{2}$。因此,点积的算术强度是 $\frac{1}{2}$;换句话说,点积每加载 1 个字节,只做 0.5 次浮点运算。这意味着它的算术强度低于硬件本身,因此会是通信受限的。上面那个 240 并不是这里正确的比较对象,因为正如下一部分会看到的,点积是在 VPU 上执行,而不是 MXU 上执行。

TPU v5p 的 VPU 大约能做到 7e12 FLOPs/s,因此它的临界算术强度大约是 3,这意味着这里仍然多少有些通信受限。无论如何,关键在于:它的算术强度既低又基本是常数,因此在大多数硬件上都很难进入计算受限。

Roofline 的可视化

我们可以用 roofline 图来可视化内存与计算之间的权衡。roofline 图的横轴是某个算法的算术强度,纵轴则是该算法在目标硬件上理论可达到的峰值 FLOPs/s(吞吐)。下面是一个对数-对数坐标图的例子:

图:一个 roofline 图示例,展示了两个具有不同算术强度的算法(Algo 1 和 Algo 2),以及它们在不同带宽(BW1 和 BW2)下各自对应的理论峰值吞吐。在红色区域中,算法在两种带宽下都受带宽限制,因此浪费了部分硬件峰值 FLOPs/s;黄色区域表示它只在较低带宽(BW1)下受带宽限制;绿色区域表示它在所有带宽下都受计算限制。

在这里,我们已经达到了加速器的峰值 FLOPs/s,因此继续提高带宽或提高算术强度都不会再带来收益。

在上图中,随着算术强度增加(从左往右移动),我们一开始会看到算法性能(以 FLOPs/s 计)线性增长,直到达到该硬件的临界算术强度;对于 TPU v5e,这个值是 240。任何算术强度低于这个值的算法,都会受带宽(BW)限制,并受峰值内存带宽约束(红色区域);位于右侧的算法则可以把 FLOPs 吃满(绿色区域)。

这里,Algo 1 是通信受限的,只利用了总硬件 FLOPs/s 中的一部分;Algo 2 则是计算受限的。一般来说,想提升一个算法的性能,要么提高它的算术强度,要么提高可用内存带宽(例如从 BW1 提升到 BW2)。

矩阵乘法

来看我们很快就会最熟悉的算法:矩阵乘法(也叫 matmul)。我们写作 $X * Y \rightarrow Z$,其中 $X$ 的形状是 $\text{bf16}[B, D]$,$Y$ 的形状是 $\text{bf16}[D, F]$,$Z$ 的形状是 $\text{bf16}[B, F]$。为了完成这个 matmul,我们需要加载 $2DF + 2BD$ 字节,执行 $2BDF$ FLOPs,并写回 $2BF$ 字节。严格来说,实际执行的是 $BF \times (2D - 1)$ FLOPs,不过这里的近似已经足够接近。这来自 $BDF$ 次乘法和 $BF * (D - 1)$ 次加法。第 4 部分里会有更多细节。

虽然 matmul 的输出从技术上说是 float32,但我们通常会在拷回 HBM 前把它 cast 回 bfloat16。因此:

\[ \begin{equation} \text{Intensity}(\text{matmul}) = \frac{2BDF}{2BD + 2DF + 2BF} = \frac{BDF}{BD + DF + BF} \end{equation} \]

如果我们假设“batch size” $B$ 相对于 $D$ 和 $F$ 都很小,就可以得到一个很好的简化:

\[ \begin{equation} \frac{BDF}{BD + DF + BF} \approx \frac{BDF}{DF} = B \end{equation} \] \[ \begin{equation} \text{Intensity}(\text{matmul}) > \text{Intensity}(\text{TPU}) \implies B > \frac{1.97e14}{8.20e11} = 240 \end{equation} \]

对于 Transformer 的 matmul 来说,这是一个合理假设,因为我们的本地(per-replica)batch size 通常满足 $B < 1024$ 个 token(不是序列),而 $D$ 和 $F > 8000$。因此,一个很简单的经验法则是:当每个 replica 上的 batch size 超过 240 个 token 时,我们通常就会变成计算受限。

这里说 per-replica,是因为如果我们通过某种模型切分来增加参与同一个 matmul 的芯片数,那么可用计算能力和内存带宽会按相同比例扩展。因此,这个临界 batch size 说的是“每一份独立模型权重副本”的规模。

结论:对于 bfloat16 matmul,要想在大多数 TPU 上进入计算受限,每个 replica 的 token batch size 需要大于 240。注意,这里的 batch size 不是通常所说以“序列数”来度量的 batch size。事实证明,大多数 roofline 只取决于 token 数量,而不取决于这些 token 属于同一条序列还是不同序列。

例如,如果你有 512 条序列、每条序列 4096 个 token,并在 128 张 GPU 上训练,那么总 batch size 就是 512 * 4096 = 2M 个 token,而本地 batch size 是 16k token。

这里还有几个值得注意的 caveat,我们会在下面的题目里进一步讨论,特别是关于量化的情况(例如,如果我们把 activation 做了量化,但仍执行全精度 FLOPs)。不过,这依然是一个很好记的经验规则。对于 GPU,这个数字会稍高一点(更接近 300),但总体结论是一样的。

当我们把一个大 matmul 分解成更小的 matmul 时,tile 的大小也会影响结果。做大矩阵乘法时,我们必须把它拆成更小的块,以便放进 VMEM/SMEM/TMEM 这些带宽更高的片上内存。这会导致某些块被重复加载,因此“我们只加载 $O(N^2)$ 字节”就不再完全成立。考虑一个 $(m, k) \cdot (k, n)$ 的 matmul,其 tile 尺寸分别为 $bm$、$bk$、$bn$。令 $tm = m / bm$,其余依此类推。

那么总 FLOPs 为:

$2 \cdot tm \cdot tn \cdot tk \cdot bm \cdot bn \cdot bk$

总字节数为:

$2 \cdot tm \cdot tn \cdot (tk \cdot (bm \cdot bk + bk \cdot bn) + 2 \cdot bm \cdot bn)$

忽略最后一项后,算术强度就是:

$bm \cdot bn / (bm + bn)$

这与前面的结论相似。更底层的 GPU 和 TPU 细节,我们会在下一部分讨论。

网络通信 roofline

到目前为止,我们讨论的 roofline 都是内存带宽 roofline,而且都发生在单芯片内部。但这不应被视为一条普遍规律。事实上,本书里我们真正关心的大多数 roofline,都涉及芯片之间的通信:通常是那些矩阵已经被分片到多块 TPU 上的矩阵乘法。

举一个稍微有些刻意的例子。假设我们要把两个大矩阵相乘:$X\sim \text{bfloat16[B, D]}$ 和 $Y \sim \text{bfloat16[D, F]}$,并把它们沿 $D$ 维平均切到 2 个 TPU/GPU 上。为了完成这个乘法(正如我们会在第 3 部分看到的),我们可以让每个 TPU 各自计算一半矩阵(TPU 0 上做 A = X[:, :D // 2] @ Y[:D // 2, :],TPU 1 上做 B = X[:, D // 2:] @ Y[D // 2:, :]),然后把得到的“部分和”复制到另一块 TPU 上,再把两边加起来。

假设每个方向都能以 4.5e10 bytes/s 复制数据,并且每片芯片都能执行 1.97e14 FLOPs/s。那么 $T_\text{math}$ 和 $T_\text{comms}$ 分别是多少?

$T_\text{math}$ 显然只有之前的一半,因为每个 TPU 只做了一半工作。也就是说,这里忽略了把两个部分和相加所需的 FLOPs(再做一次 BF 加法),但这基本可以忽略不计。

\[ T_\text{math} = \frac{2BDF}{2 \cdot \text{Accelerator FLOPs/s}} = \frac{BDF}{1.97e14} \]

那么 $T_\text{comms}$ 呢?这里它指的是芯片间的通信时间。它就是总发送字节数除以网络带宽,也就是:

\[ T_\text{comms} = \frac{2BF}{\text{Network Bandwidth}} = \frac{2BF}{4.5e10} \]

因此,当

$\text{Intensity}(\text{matmul (2-chips)}) > \text{Intensity}(\text{TPU w.r.t. inter-chip network})$

时,我们就进入了计算受限(这里说的是相对于跨芯片网络而言)。等价地:

$\frac{BDF}{2BF} = \frac{D}{2} > \frac{1.97e14}{4.5e10} = 4377$

也就是:

$D > 8755$

注意,与前面不同,这次的临界阈值取决于 $D$,而不是 $B$。可以想一想为什么会这样。

这只是这样一种例子,但它说明了:这种 roofline 对判断我们是否应该把一个操作并行到多个 TPU 上至关重要。

几个动手题

问题 1【int8 matmul】:假设我们要用 int8 精度(每个参数 1 字节)而不是 bfloat16(每个参数 2 字节)来执行矩阵乘法 $X[B, D] \cdot_D Y[D, F] \rightarrow Z[B, F]$,因为 TPU/GPU 在更低精度下能更快地执行 matmul。

这里以及下文中,我们用记号 $A \cdot_D B$ 表示这个乘法是在 $D$ 维上做 contraction。这是对 einsum 记号的一种非严格借用。

  1. 需要从内存加载多少字节?需要写回多少字节?
  2. 总共执行了多少 OPs?
  3. 算术强度是多少?
  4. 对 $T_\text{math}$ 和 $T_\text{comms}$ 的 roofline 估计是什么?整个操作运行时间的合理上下界分别是什么?

假设 HBM 带宽为 8.1e11 bytes/s,而 int8 峰值 OPs/s 为 3.94e14(大约是 bfloat16 的 2 倍)。

点击此处查看答案。
  1. 因为我们把参数存成 int8,所以每个参数占 1 个字节,因此从 HBM 加载的是 $BD + DF$ 字节,写回的是 $BF$ 字节。
  2. 这和 bfloat16 的情况相同,只不过理论上 int8 OPs/s 更快。所以总 FLOPs 仍然是 $2BDF$。
  3. 算术强度为 $2BDF / (BD + DF + BF)$。如果像上面那样假设 $B \ll D$ 且 $B \ll F$,那么算术强度就是 $2B$,也就是说经验规则变成了 $B > \text{HBM int8 arithmetic intensity} / 2$。代入给定数值,这个 int8 强度是 3.94e14 / 8.1e11 = 486,因此规则变成 $B > 486 / 2 = 243$。注意,这几乎没有变化。
  4. $T_\text{math} = 2BDF / 3.94e14$,$T_\text{comms} = (BD + DF + BF) / 8.1e11$,因此合理的下界是 $\max(T_\text{math}, T_\text{comms})$,上界是 $T_\text{math} + T_\text{comms}$。

问题 2【int8 + bf16 matmul】:在实践中,我们经常对权重和激活做不同的量化,也就是说,权重可能以很低精度存储,但激活(以及计算本身)仍保持较高精度。假设我们把权重量化成 int8,但让激活和计算保持 bfloat16。那么在多大的 batch size 下会进入计算受限?假设 bfloat16 峰值 FLOPs/s 为 1.97e14

提示:具体来说,这里指的是 bfloat16[B, D] * int8[D, F] -> bfloat16[B, F],其中 $B$ 是“batch size”。

点击此处查看答案。

同样假设 $B$ 很小,此时我们有 $2BDF$ 个 bfloat16 FLOPs,但权重只有 $DF$ 个字节(而不是 bfloat16 情况下的 $2DF$)。这意味着,当 $2B > 240$ 时我们就会进入计算受限,也就是 $B > 120$。这个阈值低了很多,这说明如果我们能做 int8 权重量化(通常相当容易),但仍执行 bfloat16 FLOPs,那么在效率上就能得到明显收益(当然,直接使用 int8 OPs 会更好)。

问题 3:沿用问题 2 的设定,分别对 $F = D = 4096$ 和 $F = D = 1024$,画出峰值 FLOPs/s 关于 $B$ 的 roofline 图。这里请使用精确的加载字节数,而不是近似值。

点击此处查看答案。

题目中的图如下:

Question 3 roofline plot

注意,这两个模型最终都会达到硬件峰值 FLOPs/s,但更大的 D/F 会更早达到这一点。D=F=1024 几乎会让临界 batch size 翻倍。生成这张图的代码如下:

import matplotlib.pyplot as plt
import numpy as np

bs = np.arange(1, 512)

def roofline(B, D, F):
  total_flops = 2*B*D*F
  flops_time = total_flops / 1.97e14
  comms_time = (2*B*D + D*F + 2*B*F) / 8.2e11
  total_time = np.maximum(flops_time, comms_time)
  return total_flops / total_time

roofline_big = roofline(bs, 4096, 4096)
roofline_small = roofline(bs, 1024, 1024)
plt.figure(figsize=(8, 4))
plt.plot(bs, roofline_big, label='F=D=4096')
plt.plot(bs, roofline_small, label='F=D=1024')
plt.legend()
plt.xlabel('batch size')
plt.ylabel('peak bfloat16 FLOPs/s on TPU v5e')
plt.grid()

问题 4:如果我们想执行 $\text{int8[B, D]} *_D \text{int8[B, D, F]} \rightarrow \text{int8[B, F]}$,也就是设想 batch 中每个元素都对应一个不同的矩阵,那么这个操作的算术强度是多少?

点击此处查看答案。

先从总 FLOPs 和总通信量开始看。

  1. 总 FLOPs:FLOPs 基本和前面一样,因为我们做的是相同数量的 $BD \times DF$ matmul(第 4 部分会进一步讨论)。因此这里仍然是 $2BDF$。
  2. 总通信量:这里的通信要多得多,为 $BD + BDF + BF$。
  3. 因此,算术强度变成了 $2BDF / (BD + BDF + BF)$。由于分母里主导项是 $BDF$,它大约等于 $2$。也就是说,它不再依赖 batch size,而几乎是常数。这是个坏消息,因为这意味着无论 batch size 多大,我们基本都会是通信受限的。

问题 5【GPU 的内存 roofline】:利用 NVIDIA H100 SXM 的规格页,计算 bfloat16 矩阵乘法会在多大的 batch size 下变成计算受限。注意,Tensor Core FLOPs 数字是实际值的 2 倍,因为它们只有在结构化稀疏下才能达到。

点击此处查看答案。

从规格页中我们可以看到,标注的 bfloat16 FLOPs 值是 1.979e15 FLOPs/s,并带有一个注明 “with sparsity” 的星号。不考虑稀疏性时,真实值应当减半,也就是接近 1e15 FLOPs/s。内存带宽是 3.35TB/s,也就是 3.35e12 bytes / second。因此,$B_\text{crit}$ 为 1e15 / 3.35e12 = 298,和 TPU 的结果相当接近。

第 1 部分到这里结束!关于第 2 部分(真实 TPU 如何处理 FLOPs 和通信),请点这里

杂项

*这项工作完成于 Google DeepMind,作者现就职于 MatX。

引用

如果你要在学术场景中引用这项工作,请使用:

Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或者使用下面的 BibTeX 条目:

@article{scaling-book,
  title = {How to Scale Your Model},
  author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
  and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
  publisher = {Google DeepMind},
  howpublished = {Online},
  note = {Retrieved from https://jax-ml.github.io/scaling-book/},
  year = {2025}
}
第 2 章

如何理解 TPU

你可能也会喜欢阅读关于 NVIDIA GPU 的新 [第 12 节](../gpus)!

什么是 TPU?

TPU 本质上是一个专门擅长矩阵乘法的计算核心(称为 TensorCore),外加一组高速内存堆栈(称为高带宽内存,HBM) 下图展示了它的结构:

<b>图:</b>TPU 芯片的基本组成。TensorCore 是左侧的灰色方框,内部包含矩阵乘法单元(MXU)、向量单元(VPU)和向量内存(VMEM)。
图:TPU 芯片的基本组成。TensorCore 是左侧的灰色方框,内部包含矩阵乘法单元(MXU)、向量单元(VPU)和向量内存(VMEM)。

你可以把 TensorCore 理解为一台非常擅长做矩阵乘法的机器,但它还有几个值得注意的功能。TensorCore 有三个关键单元:

TPU 在矩阵乘法上非常、非常快。 这基本上就是它最擅长的事情,而且做得极好。TPU v5p 是目前最强大的 TPU 之一,每个 core 可达到 2.5e14 bf16 FLOPs / second / core,或每个 chip 达到 5e14 bf16 FLOPs / sec / chip。一个包含 8960 个芯片的 pod 每秒可达到 4 exaflops。这个规模非常惊人,它属于世界上最强大的超级计算机之列。而 Google 拥有很多这样的系统。TPU,尤其是其中的脉动阵列,之所以能成为如此强大的硬件加速器,是因为矩阵乘法是少数几种以 $O(n^2)$ 字节实现 $O(n^3)$ 计算量的算法之一。这使得普通 ALU 很容易受计算能力限制,而不是受内存带宽限制。

上面的图还包含了其他一些组件,例如 SMEM 和 scalar unit,它们用于控制流处理,并会在附录 A中简要讨论,但并不是理解 TPU 的关键。相比之下,HBM 既重要也相对简单:

一般来说,TPU 上的所有操作都会被流水化并彼此重叠。 要执行一次矩阵乘法 $X \cdot A \to Y$,TPU 需要先把矩阵 $A$ 和 $X$ 的若干块从 HBM 复制到 VMEM,再将它们加载进 MXU,后者会对 8x128(用于 $X$)和 128x128(用于 $A$)的块执行乘法,然后再把结果逐块复制回 HBM。为了高效完成这一过程,矩阵乘法会被流水化,从而使往返 VMEM 的复制与 MXU 的工作重叠进行。这样 MXU 就可以持续工作,而不必等待内存传输,从而让矩阵乘法保持计算受限,而不是内存带宽受限。

下面是一个从 HBM 执行逐元素乘法的示例:

<b>图:</b>动画展示了 TPU 上执行逐点乘法时,字节如何从 HBM 被加载。注意数据会以块的形式从内存中流出,部分结果也会以流水方式回写,而无需等待整个数组完全物化。
图:动画展示了 TPU 上执行逐点乘法时,字节如何从 HBM 被加载。注意数据会以块的形式从内存中流出,部分结果也会以流水方式回写,而无需等待整个数组完全物化。

矩阵乘法看起来几乎是一样的,只不过它会把数据送入 MXU 而不是 VPU/Vector unit,而且加载和存储的顺序也会不同,因为同一个权重块会被多个激活块重复使用。你可以看到数据块流入 VMEM,再进入 VREG(向量寄存器),随后进入 Vector Unit,最后再返回 VMEM 和 HBM。正如我们马上会看到的,如果从 HBM 到 VMEM 的加载速度慢于 Vector Unit(或 MXU)的 FLOPs 处理速度,我们就会变成 “bandwidth bound”,因为 VPU 或 MXU 会因为缺少工作而被饿住。

**关键结论:** TPU 非常简单。它们把权重从 HBM 加载到 VMEM,再从 VMEM 加载到脉动阵列,而后者每秒可以执行大约 200 万亿次乘加。HBM $\leftrightarrow$ VMEM 和 VMEM $\leftrightarrow$ 脉动阵列之间的带宽,为 TPU 能高效执行哪些计算设定了根本限制。

VMEM 与算术强度: VMEM 比 HBM 小得多,但与 MXU 之间的带宽高得多。正如我们在第 1 节中看到的,这意味着如果某个算法能把所有输入/输出都装进 VMEM,它就不太容易遇到通信瓶颈。这在算术强度较差的计算中尤其有帮助:VMEM 带宽大约是 HBM 带宽的 22 倍,这意味着一个从 VMEM 读写的 MXU 操作,只需要 10-20 的算术强度就能达到峰值 FLOPs 利用率。这意味着如果我们能把权重放进 VMEM 而不是 HBM,我们的矩阵乘法在更小的 batch size 下也能变成 FLOPs 受限。也意味着那些从根本上具有较低算术强度的算法仍然可以高效,只是 VMEM 太小,这往往会成为挑战。我们有时会讨论 VMEM prefetching,即提前把权重加载进 VMEM,以便在执行矩阵乘法时隐藏加载开销。比如在普通 Transformer 中,我们有时可以在 attention 阶段把较大的前馈层权重加载到 VMEM 中;如果当前工作负载受内存带宽限制,这就能隐藏掉权重加载的代价。这要求权重足够小,或者切分得足够细,以便单层权重能够装入 VMEM,且还有额外空间。

assets/img/tpu-bandwidth.png

一个 TPU 芯片通常(但不总是)由两个共享内存的 TPU core 组成,因此可以被视为一个 FLOPs 翻倍的大型加速器(称为 “megacore” 配置)。从 TPU v4 开始就是如此。更早的 TPU 芯片拥有彼此独立的内存,因此被视为两个独立的加速器(TPU v3 及更早版本)。像 TPU v5e 这样的推理优化芯片则每个芯片只有一个 TPU core。

assets/img/cores.png

芯片会以每 4 个组成一个 “tray” 的方式排列,并通过 PCIe network 连接到 CPU host。 这是大多数读者最熟悉的形式:通过 Colab 或单个 TPU-VM 暴露出 4 个芯片(8 个 core,不过通常按 4 个逻辑 megacore 对待)。对于像 TPU v5e 这样的推理芯片,每个 host 有 2 个 tray,而不是 1 个,但每个芯片也只有 1 个 core,因此总共得到 8 个芯片 = 8 个 core。在 Cloud TPU VM 上,每个 tray 会作为独立 VM 的一部分暴露出来,因此可见的 core 数又回到了 4 个。

assets/img/pcie.png

PCIe 带宽是有限的: 与 HBM $\leftrightarrow$ VMEM 链路类似,CPU $\leftrightarrow$ HBM 的 PCIe 连接也具有特定带宽,它限制了你从 host memory 向 HBM 加载数据或反向卸载数据的速度。以 TPU v4 为例,PCIe 带宽双向各为 16GB / second,因此比 HBM 慢接近 100 倍。我们确实可以把数据加载到 host(CPU)RAM 或从中卸载出来,但速度并不快。

TPU 网络

在 Pod 中,芯片之间通过 ICI 网络相连。 在较早代际(TPU v2 和 TPU v3)、推理芯片(例如 TPU v5e)以及 Trillium(TPU v6e)中,ICI(“inter-chip interconnects”)会连接最近的 4 个邻居(并通过边缘链路形成 2D 环面)。TPU v4 和 TPU v5p 则连接最近的 6 个邻居(形成 3D 环面)。请注意,这些连接不会经过它们的 host,而是芯片之间的直接链路。

assets/img/ici-wraparound.png

环面结构会把任意两个节点之间的最大距离从 $N$ 降低到 $N / 2$,从而让通信更快。TPU 还有一种 “twisted torus” 配置,它会以类似莫比乌斯带的拓扑方式对环面进行缠绕,从而进一步降低节点间的平均距离。

通过 ICI 连接的 TPU pod 可以非常大: TPU v4 的最大 pod 大小(称为 superpod)是 16x16x16,而 TPU v5p 是 16x20x28。这些大型 pod 由可重构的 4x4x4 芯片立方体组成,立方体之间通过optical wraparound linksOptical switch 本质上只是一个具有相同 ICI 带宽的可重构连接。它的作用只是让我们在保留 wraparound link 的同时连接多个立方体。相连,从而可以重构出非常大的拓扑。

assets/img/tpu-rack.png

也可以申请较小的拓扑(例如 2x2x12x2x2),不过它们没有 wraparound。这是一个重要注意点,因为这通常会让大多数通信的时间翻倍。任何完整立方体的整数倍(例如 4x4x44x4x8)都会获得由 optical switches 提供的 wraparound。注意,2x2x4 不会有任何 wraparound,因为这些 wraparound 由 optical switches 提供,而它们只在完整立方体上可用。不过,TPU v5e 8x16 在较长的轴上拥有 wraparound,因为它不使用可重构 optical networking。

assets/img/subslices.png

TPU v5e 和 Trillium pod 由单个 16x16 2D 环面组成,只要某个轴的长度为 16,该轴上就具有 wraparound(这意味着 8x16 会在长轴上拥有 wraparound)。TPU v5e 和 v6e(Trillium)不能扩展到超出 16x16 的环面,但不同 pod 之间仍然可以通过标准数据中心网络(DCN)互相通信,它负责把 TPU host 彼此连接起来。再次强调,较小的拓扑也可以申请,只是维度 $<16$ 的轴不会有 wrap。

assets/img/more-subslices.png

这种最近邻连接是 TPU 与 GPU 之间的一个关键差异。 GPU 通过分层交换机结构连接,这种结构近似于让每个 GPU 都能与其他 GPU 建立点对点连接,而不是像 TPU 那样依赖局部连接。通常,同一节点内的 GPU(H100 为 8 个,B200 NVL72 最多可达 72 个)是直接相连的,而更大的拓扑则要求每个 GPU 之间经过 O(log(N)) 次跳转。一方面,这意味着 GPU 可以在较少跳数内发送任意数据。另一方面,TPU 的成本要低得多(因为 NVLink 交换机很昂贵)、布线也更简单,而且由于每个设备的链路数和每个设备的带宽都是常数,它们可以扩展到更大的拓扑。更多内容请见这里

ICI 相对于 DCN 非常快,但仍然慢于 HBM 带宽。 例如,TPU v5p 具有:

这意味着,当我们把模型拆分到多个芯片上时,必须小心避免让更慢的跨设备通信成为 MXU 的瓶颈。

多 slice 训练: 一组通过 ICI 互连的 TPU 被称为一个 slice。不同的 slice 之间可以通过 DCN 互连,例如把不同 pod 上的 slice 连接起来。由于 DCN 比 ICI 慢得多,我们应尽量减少计算等待 DCN 数据的时间。DCN 是 host-to-host 的,因此要通过 DCN 在 TPU 之间传输 buffer,我们首先需要通过 PCIe 把数据传到 host,然后通过网络发出,再从目标 host 网络接收,最后再通过 PCIe 进入 HBM。

关键结论

TPU 规格

下面列出了一些芯片的具体参数:

Model Pod size Host size HBM capacity/chip HBM BW/chip (bytes/s) FLOPs/s/chip (bf16) FLOPs/s/chip (int8)
TPU v3 32x32 4x2 32GB 9.0e11 1.4e14 1.4e14
TPU v4p 16x16x16 2x2x1 32GB 1.2e12 2.75e14 2.75e14
TPU v5p 16x20x28 2x2x1 96GB 2.8e12 4.59e14 9.18e14
TPU v5e 16x16 4x2 16GB 8.1e11 1.97e14 3.94e14
TPU v6e 16x16 4x2 32GB 1.6e12 9.20e14 1.84e15

Host size 指的是连接到单个 host 的 TPU 拓扑(例如 TPU v5e 中,一个 CPU host 连接 8 个 TPU,形成 4x2 拓扑)。下面是互连参数:

Model ICI BW/link (one-way, bytes/s) ICI BW/link (bidi, bytes/s)
TPU v3 1.0e11 2.0e11
TPU v4p 4.5e10 9.0e10
TPU v5p 9.0e10 1.8e11
TPU v5e 4.5e10 9.0e10
TPU v6e 9.0e10 1.8e11

我们同时给出 one-way(unidirectional)和 bidi(bidirectional)带宽,因为单向带宽更贴近硬件本身,而双向带宽更常出现在涉及完整 ring 的公式中。这里所说的 bidi(bidirectional)带宽,是指一条链路在两个方向上总共可发送的字节数;或者等价地说,在我们能够高效使用两条链路的前提下,它也代表单个 TPU 沿某个特定轴的总出向字节数。当我们拥有一个真正工作的 ring,也就是该轴上存在 wraparound 连接时,这个定义才成立。这种情况会出现在推理芯片的完整 16 轴上,或训练芯片(v*p)中某个轴长度为 4 的倍数时。我们更倾向于使用双向带宽,因为它在双向通信的计算中非常常见。

PCIe 带宽通常约为每个 TPU 1.6e10 bytes / second(TPU v6e 为 3.2e10),而 DCN 带宽通常约为每个 TPU 6.25e9 bytes / second(TPU v6e 为 12.5e9,TPU v5e 为 3.125e9)。

习题示例

这些数字本身有点枯燥,但它们能够帮助你对模型性能做出基本的 roofline 估算。下面我们通过几道题来说明这为什么有用。你会在第 3 部分里看到更多例子。

问题 1【LLM 延迟下界】: 假设你想从一个 200B 参数、bf16 精度、切分在 32 个 TPU v4p 上的模型中进行采样。把所有参数从 HBM 加载到脉动阵列需要多长时间?提示:使用上面的数字。

点击查看答案。 **答案:** 我们需要加载 `sizeof(bf16) * 200e9 = 400e9` 字节,分布在 32 个芯片上,也就是每个芯片 12.5e9 字节,而每个芯片的 HBM 带宽为 1.23e12。因此加载大约需要 10ms。 这很有意思,因为*这给出了模型采样延迟的一个合理下界*。每一步采样都需要从 HBM 中加载全部参数,因此它不可能低于 10 ms。在小 batch size 下,这在实践中已经接近可实现。

问题 2【TPU 细节】: 考虑一个完整的 TPU v5e pod。总共有多少个 CPU host?有多少个 TPU TensorCore?整个 pod 的总 FLOPs/s 是多少?HBM 总量是多少?再对 TPU v5p pod 做同样的计算。

点击查看答案。 **答案:** 对于 TPU v5e,每个 pod 是 `16x16`,每个 host 是一个 4x2 slice,因此共有 `16*16 / 8 = 32` 个 host。对于 TPU v5e,每个 TPU 只有一个 core,因此总共有 256 个 TensorCore。总 FLOPs/s 为 `16*16*2e14 = 5.1e16`(bfloat16)。每个芯片有 16GB HBM,因此总内存为 `256 * 16 = 4TB`。 对于完整的 TPU v5p pod,我们有 `16x20x28` 个芯片,而每个 host 是 2x2x1,因此共有 `(16*20*28) / (2*2) = 2,240` 个 host。对于 TPU v5p,每个 TPU 有两个 TensorCore,因此共有 `8960 * 2 = 17,920` 个 core。总 FLOPs/s 为 `8960 * 4.5e14 = 4e18`(bfloat16)。每个芯片有 96GB HBM,因此总内存为 `8960 * 96 = 860TB`。

问题 3【PCIe 运算强度】: 假设我们不得不把一个大型权重矩阵 $A$(类型为 $\text{bfloat16}[D, F]$)和一批激活 $x$(类型为 $\text{bfloat16}[B, D]$)存放在 host DRAM 中,并希望对它们执行矩阵乘法。该操作运行在单个 host 上,我们使用的是连接到它的一块 TPU v6e 芯片。你可以假设 $B \ll D$,并且 $F = 4D$(我们会在后续章节中解释为什么这是合理假设)。为了在 PCIe 上保持 FLOPs 受限,我们所需的最小 batch size $B$ 是多少?假设 PCIe 带宽为 1.5e10 bytes / second。

点击查看答案。 **答案:** 我们需要执行 $2BDF$ 次浮点运算,而每个芯片每秒可以执行 `9.2e14` 次浮点运算。因此计算需要 $2BDF / 9.2e14$ 秒。我们需要从 DRAM 加载 $2DF + 2BD$ 字节,并回写 $2BF$ 字节。由于我们受 PCIe 传输速度限制,因此把数据传入和传出 TPU 需要 $2 \cdot (BD + DF + BF) / 1.5e10$ 秒。既然我们希望在能够把权重加载与计算重叠的前提下,计算时间长于权重加载时间,那么我们希望满足 $2BDF / 9.2e14 > 2 \cdot (BD + DF + BF) / 1.5e10$。利用 $B \ll D$ 且 $F = 4D$ 的假设,可化简为 $$\frac{8BD^2}{9.2 \times 10^{14}} > \frac{8D^2}{1.5 \times 10^{10}}$$ 也就是 $$B > \frac{9.2 \times 10^{14}}{1.5 \times 10^{10}} \simeq 61{,}000$$

问题 4【一般矩阵乘法延迟】: 假设我们要将一个权重矩阵 int8[16384, 4096] 与一个大小为 int8[B, 4096] 的激活矩阵相乘,其中 B 是某个未知的 batch size。先假设我们运行在 1 个 TPU v5e 上。

  1. 这个乘法作为 B 的函数需要多长时间?提示:可以分别计算从 HBM 加载数组所需的时间,以及实际执行乘法所需的时间。到底哪一项是瓶颈?
  2. 如果我们希望从 VMEM 中运行这个操作,会需要多长时间?
点击查看答案。 **答案:** (1) 我们需要执行的浮点运算数为 $2 \cdot 4096 \cdot 16384 \cdot B = 1.3 \times 10^{8} \cdot B$。因此 $T_{\text{math}} = (1.3 \times 10^{8} \cdot B) / 3.94 \times 10^{14}$ 秒。我们需要把 $16384 \cdot 4096 + 4096 \cdot B$ 字节从 HBM 加载到 VMEM,并将 $16384 \cdot B$ 字节从 VMEM 写回 HBM。因此,$T_{\text{comms}} = (6.7 \times 10^{7} + 2 \times 10^{4} \cdot B) / 8.1 \times 10^{11}$ 秒。假设通信和计算可以尽可能重叠,则整个乘法大约需要 $$\max\{T_{\text{math}}, T_{\text{comms}}\} = \max\left\{ \frac{6.7 \times 10^{7} + 2 \times 10^{4} \cdot B}{8.1 \times 10^{11}}, \frac{1.3 \times 10^{8} \cdot B}{3.94 \times 10^{14}} \right\}$$ 当 $\frac{6.7 \times 10^{7} + 2 \times 10^{4} \cdot B}{8.1 \times 10^{11}} < \frac{1.3 \times 10^{8} \cdot B}{3.94 \times 10^{14}}$ 时,我们就是 FLOPs 受限。等价地,$B > 271$。这比我们在[第 1 节](../roofline)中推导出的 240 略大,因为这里考虑了 $D$ 和 $F$ 的完整影响。 (2) 如果改为从 VMEM 加载,我们把 VMEM 到 MXU 的带宽视作 HBM $\leftrightarrow$ VMEM 带宽的 22 倍。这样数据加载公式中的分母就会从 8.1e11 变成 1.78e13,于是得到 $B > 11$。注意,在实践中我们无法把全部 VMEM 带宽都用来加载 $W$,因此实际情况通常会更接近 20。

问题 5【ICI 带宽】: 假设我们有一个 TPU v5e 4x4 slice。现在要把一个类型为 bfloat16[8, 128, 8192] 的数组从 TPU{0,0} 发送到 TPU{3, 3}。假设 TPU v5e 的每跳延迟为 $1\mu s$。

  1. 第一个字节最早会在多久后到达目的地?
  2. 整个传输总共需要多长时间?
点击查看答案。 **答案:** TPU v5e 具有 2D 连接。由于这里只有一个 `4x4` slice(且没有任何长度为 16 的轴),所以没有 wraparound 连接。因此,目标芯片有两个端口可以接收数据,源芯片也同样有两个端口可以发送数据。需要传输的数据量为 `2 * 8 * 128 * 8192 = 1.7e7` 字节。我们可以同时通过两个端口传输(即一半向右发送,一半向下发送),因此总传输带宽为 `2 * 4.5e10 = 9e10` bytes per second,这意味着整个数组的传输时间约为 `1.7e7 / 9e10 = 188us`(假设我们受带宽限制)。在一个 `4x4` slice 中,芯片 $(0, 0)$ 与 $(3, 3)$ 之间需要经过 6 跳,因为长度小于 16 的轴上没有 wraparound 链路。若每跳延迟约为 $1\mu s$,则第一个字节大约会在 `6us` 后到达,而整个传输将耗时 `188us`。

问题 6【综合题,较难】: 假设你有一个大矩阵 Aint8[128 * 1024, 128 * 1024],它被均匀切分在一个 TPU v5e 4x4 slice 上,但每个芯片上的对应分片都被卸载到 host DRAM。现在你想把整个数组复制到 TPU{0, 0},并让它与一个向量 bf16[8, 128 * 1024] 相乘。这需要多长时间?提示:使用上面的数字。

点击查看答案。 **答案:** 先来梳理我们必须执行的操作。这个数组大约有 16GB。根据上表,TPU v5e host 具有 4x2 拓扑,因此一个 4x4 对应 2 个 host。于是,因为数组是均匀切分的,每个 host 实际上持有整个数组的一半,也就是 8GB。我们需要把这些块全都复制到 TPU{0,0},因此有两个选择: 1. 通过 DCN 复制,然后再通过 PCIe 把整个未切分数组加载进 HBM。 2. 先把各自切分后的数组加载到对应 TPU 上,然后通过 ICI 执行 gather,最后在 TPU{0,0} 上完成矩阵乘法。 很明显,方案 (2) 更好。与 ICI 相比,DCN 太慢了;而且相比只使用 host 0 上的少数几条 PCIe 链路,我们更希望通过更多 PCIe 链路来加载这个大数组。下图展示了系统的一部分。如上文所述,请注意 TPU 通过 ICI 与邻居相连(即使跨 host 也是如此),所有 TPU 都通过 PCIe 与各自 host CPU 相连,而 host 之间则通过 DCN 相连。
实际上每个芯片都有一条独立的 PCIe 链路连接到其 host,不过为了清晰起见,这里只画出了一条。
实际上每个芯片都有一条独立的 PCIe 链路连接到其 host,不过为了清晰起见,这里只画出了一条。
下面我们逐步估算每一部分所需的时间: 1. **PCIe 加载:** 我们通过 16 条 PCIe 链路加载共 16GB 的数据,每条链路的带宽为 `1.5e10` bytes/second。因此这一步大约需要 66ms。 2. **ICI 复制:** 此时每个 TPU 都持有 `16GB / 16 = 1GB` 的数组。我们的 ICI 带宽是每条链路双向 `9e10` bytes/second,你会注意到,从上图可以看出,对于 TPU{0,0} 来说,在这个拓扑中 4 条 ICI 链路里实际只有 2 条在被使用。由于 TPU{0,0} 需要沿 2 个轴接收总共 15GB 的数据,且每条链路速度为 `4.5e10` bytes/s/link,因此我们可以给出一个时间下界:`15e9 / (4.5e10 * 2) = 167ms`。实践中这大概难以达到,因为负载分布非常不均匀,但应该不会差超过 2 倍。正如你将在第 3 节看到的,执行一次完整的 AllGather 也大致需要 `16e9 / (4.5e10 * 2)`,因此这个估计已经接近最优。 3. **HBM $\rightarrow$ MXU 加载:** 为了执行最终的矩阵乘法,我们需要把这 16e9 字节以及 bf16[8, 128 \* 1024] 数组(另外 2MB 左右,因此可忽略)通过 HBM 带宽加载到 MXU 中,这一步需要 `16e9 / 8.1e11 = 19ms`。 4. **FLOPs:** 我们总共需要执行 $$2 \cdot 8 \cdot 128 \cdot 1024 \cdot 128 \cdot 1024 = 2.7 \times 10^{11}$$ 次 FLOPs,而 TPU v5e 可以达到 `1.97e14` bf16 FLOPs/s,因此计算时间约为 1.3ms。 总时间的一个上界是以上时间之和;但由于 TPU 通常可以将这些操作重叠执行,我们更适合把它看成一个由最慢阶段决定的流水问题。如果这个假设成立,那么答案至少是 167ms;考虑到重叠并不完美,更可能接近 200ms。

第 2 部分到这里就结束了!如果想继续阅读关于分区与跨 TPU 通信的第 3 部分,请[点击这里](../sharding)。

附录

附录 A:更多 TPU 内部机制

这里我们将更深入地讨论 TPU 的内部工作机制。除非特别说明,以下规格都以 TPU v5p 为例。

VPU

VPU 是 TPU 的向量算术核心。它由一个二维 SIMD 向量机器(即 VPU)和一组称为 VREGs 的向量寄存器组成。前者负责执行诸如 vadd(向量加法)或 vmax(逐元素最大值)之类的逐元素算术操作,后者则为 VPU 和 MXU 保存数据。

VREGs: 每个 TPU v5p core 有 64 个 32-bit VREG(TPU v4 上是 32 个),因此每个 core 的 VREG 内存总量约为 64 * 8 * 128 * 4 = 256kB(整个 chip 则是它的 2 倍,因为有两个 core)。TPU v5p 每个周期可以从 VMEM 加载 3 个寄存器,并向 VMEM 写回 1 个寄存器。

VPU: VPU 是一个形状为 (8, 128) 的二维向量算术单元,其中维度 128 被称为 lane axis,维度 8 被称为 sublane axis。在 v5 中,每个 (lane, sublane) 对上都包含 4 个彼此独立的标准浮点 ALU。VPU 会在每个 ALU 上以一个周期执行大多数算术指令(例如 vadd,也就是向量加法),其延迟为 2 个周期,因此例如在 v5 上,你可以在每个周期里从 VREG 中取出两组 f32 数值并做 4 次加法。一个典型的 VPU 指令可能看起来像 {v2 = vadd.8x128.f32 v0, v1},其中 v0 和 v1 是输入 VREG,v2 是输出 VREG。

所有 lane 和 sublane 都会在每个周期以纯 SIMD 的方式执行同一个程序,但每个 ALU 可以执行不同的操作。因此,例如我们可以在一个周期内同时执行 1 条 vadd 和 1 条 vsub,它们各自作用于两组完整的 VREG,并把输出写入第三个寄存器。

小测验【计算 VPU 吞吐量】: 根据以上信息,计算一块 TPU v5p 能执行多少向量 FLOPs/s。TPU v5p 的主频大约为 1.75GHz。

点击查看答案。 *答案*:每个周期中,每个 core 可以在 `8 * 128` 个 ALU 上执行 4 条向量指令。因此整个芯片每周期可达到 `8 * 128 * 4 * 2` FLOPs/cycle,即 `8 * 128 * 4 * 2 * 1.75e9 = 1.4e13 FLOPs/s`。注意,这相比 MXU 大约 `2e14` FLOPs/s 的吞吐量要小得多(大约低 10 倍)。

归约: 一般来说,跨 sublane 维度进行通信或归约,要比跨 lane 维度容易得多。例如,VPU 支持一种 intra-lane shuffle 操作,它可以在长度为 8 的那个轴上以大约一个周期的代价执行滚动。这可以用来高效地沿 sublane 维度做归约(只需按 4、2 和 1 做 shuffle,并执行 3 轮逐元素求和)。

跨 lane 的归约则要困难得多,需要借助一个独立的硬件单元,称为 XLU 或 “cross lane unit”,它速度较慢,代价也相当高。

与 GPU 的对比: 对于熟悉 NVIDIA GPU 的读者来说,VPU 中的每个 ALU 都可以类比为一个 CUDA core,而单个 VPU lane 则类似于一个 “Warp Scheduler”,即通常由 32 个 CUDA Cores 组成、负责执行 SIMD 算术的那组单元。lane 内部的归约相对容易,但如果需要跨 lane,就至少必须经过 VMEM/XLU/SMEM,这会慢得多。更多内容请见GPU 章节

Scalar Core

Scalar core 是 TPU 的控制单元。它负责获取和分发所有指令,执行从 HBM 到 VMEM 的传输,也可以被编程来处理标量元数据工作。由于 scalar core 是单线程的,一个副作用就是 TPU 的每个 core 每个周期只能发起一个 DMA 请求。

为了帮助理解这个比例:一个 scalar core 控制着一个 VPU(包含 4096 个 ALU)、4 个 MXU、2 个 XLU,以及多个 DMA 引擎。每单位计算所对应的控制逻辑极度偏斜,这既是硬件高效率的来源之一,也限制了系统进行复杂数据依赖向量化的能力。

附录 B:脉动阵列是如何工作的?

TPU 的 MXU 核心是一个 128x128 的脉动阵列(在 TPU v6e 上为 256x256)。在完全饱和时,该脉动阵列每 8 个时钟周期可以执行一次 bfloat16[8,128] @ bf16[128,128] -> f32[8,128]如果你不熟悉这种记法,它表示:将一个元素类型为 bfloat16 的 8x128 矩阵,与一个元素类型为 bfloat16 的 128x128 矩阵相乘,并把结果存入一个元素类型为 float32 的 8x128 矩阵中。 乘法。

下面是一个简化动画,展示了一组权重(蓝色)与一组激活(绿色)相乘的过程。你会注意到,权重(RHS)首先以对角线方式部分加载,随后激活也以对角线方式送入。下方每一帧中,我们都会把所有重叠的绿色单元和蓝色单元相乘,将结果与从上方传下来的残余值相加,然后再把结果继续向下传递一个单元。

assets/img/systolic-array.gif

下面是一个更通用的动画版本,展示输出如何从计算中以流式方式写出:

assets/img/systolic-array2.gif

下面这张图展示了如何在多个 RHS 和 LHS 数组之间进行流水化:

assets/img/systolic-array-pipelining.png

在最开始,随着权重(RHS)和激活(LHS)被加载,会存在一个初始的流水线气泡。这个初始气泡结束之后,就可以持续加载新的输入和权重,而不会再产生额外气泡。

下面是一个不太精致的动画,展示了 bf16[2, 3] x bf16[3, 3] 的矩阵乘法。你可以把它想象成:一个 2x3 的权重矩阵,与一个 batch 为 1、大小为 3 的输入激活执行矩阵乘法。这个动画相较于前面的示意图做了旋转,输入会向右流动而不是向下,但你仍然可以大致看出它的结构。

assets/img/systolic-array-bad.gif

我们可以对这一过程进行高效流水化,从而完成大型矩阵乘法,而不会引入过大的流水线气泡。即便如此,依然有一点很重要:我们的矩阵形状必须大于 MXU 的边长,通常也就是 128x128。某些 TPU(从 TPU v3 开始)拥有多个 MXU,TPU v3 有 2 个,而 TPU v4/5 有 4 个,因此我们还需要确保分块维度大于 128 * MXU 数量。这里有一个不错的动画可以辅助理解。

Trillium(TPU v6e)拥有一个 256x256 的脉动阵列,这意味着它每个周期可以执行 4 倍 FLOPs。这也意味着,为了把 MXU 充分利用起来,你的张量维度也需要扩大为原来的两倍。

这篇博文还提供了另一个非常出色的固定权重矩阵脉动阵列乘法动画示例。

第 3 章

分片矩阵以及如何做矩阵乘法

划分记号与集体通信操作

当我们在一万张 TPU 或 GPU 上训练一个 LLM 时,从抽象上看,做的仍然和在一张卡上训练时是同一种计算。不同之处在于,数组放不进单个 TPU/GPU 的 HBM,所以必须把它们拆开。还值得注意的是,我们有时也会为了速度而选择并行化。即使模型能放进更少的芯片里,扩展到更多芯片也只是意味着我们能得到更高的 FLOPs/s。例如在推理时,我们有时能放进更小的拓扑,但仍会扩到更大的拓扑以降低延迟;类似地,在训练时我们也常常扩到更多芯片,以减少 step 时间。 我们把这件事称为“分片”或“划分”数组。扩展的艺术,就在于找到一种对模型分片的方式,使计算依然高效。

下面是一个二维数组 A 在 4 张 TPU 上分片的例子:

<b>图:</b> 一个形状为 <b>A</b>[I, J] 的数组被分片到 4 个设备上。两个维度都分别沿 2 个设备均匀分片,因此分片方式为 <b>A</b>[I<sub>X</sub>, J<sub>Y</sub>]。每张 TPU 持有总内存的 1/4。
图: 一个形状为 A[I, J] 的数组被分片到 4 个设备上。两个维度都分别沿 2 个设备均匀分片,因此分片方式为 A[IX, JY]。每张 TPU 持有总内存的 1/4。

注意,分片后的数组仍然具有和未分片数组相同的全局逻辑形状,例如 (4, 128);但它还有一个设备本地形状,例如 (2, 64),这才决定了每张 TPU 实际持有多少字节(上图中每张 TPU 持有整个数组的 1/4)。下面我们把这一点推广到任意数组。

统一的分片记号

我们使用一种“命名轴记号(named-axis notation)”的变体,来描述张量是如何按块分布到设备上的:我们假设存在一个二维或三维的设备网格,称为 device mesh,其中每个轴都有名字,例如 XYZ。然后,我们通过说明数组的每个命名维度如何沿物理 mesh 轴划分,来描述矩阵数据在 device mesh 上的布局。这个映射就叫做 sharding(分片方式)

例子(上图):对于上面的图,我们有:

结合起来看,我们知道单个设备所持有的 shard 的本地形状是 $(\lvert I\rvert / 2, \lvert J\rvert / 2)$,其中 $\lvert I\rvert$ 是 A 的第一维大小,$\lvert J\rvert$ 是 A 的第二维大小。

思考题 [沿 1 个轴做二维分片]: 考虑数组 fp32[1024, 4096],分片方式为 $A[I_{XY}, J]$,mesh 为 {'X': 8, 'Y': 2}。每个设备持有多少数据?若在 H100 上从 HBM 读取该数组(假设每卡内存带宽为 3.4e12),需要多长时间?

点击查看答案 $A[I_{XY}, J]$ 表示把第一维(I)同时沿硬件轴 X 和 Y 分片。在这个例子里,本地形状为 $(\lvert I\rvert /(\lvert X\rvert \cdot \lvert Y\rvert), \lvert J\rvert)$。给定全局形状为 `fp32[1024, 4096]`,所以本地形状为 `fp32[64, 4096]`。 每张 GPU 上的数据量为 `4 * 64 * 4096 = 1MiB`,因此读取时间大约是 `1e6 / 3.4e12 = 294ns`。不过由于数组太小,实际时间通常会明显更长,因为各种固定开销会占主导。

把这些分片可视化: 下面通过一个二维数组分布在 4 个设备上的例子,来直观看看这些分片方式:

assets/img/sharding-colored1.png

完全复制(fully replicated)的矩阵写作 $A[I, J]$,不带任何分片标注。这表示每个设备都持有整个矩阵的一份完整拷贝。

assets/img/sharding-colored2.png

我们可以用下标 mesh 轴来表示某个逻辑维度沿某个 mesh 轴被划分。例如,$A[I_X, J]$ 表示逻辑轴 I 沿 mesh 维 X 分片,而 J 没有分片,因此数据块会沿 Y 轴保持部分复制

assets/img/sharding-colored3.png

$A[I_X, J_Y]$ 表示逻辑轴 I 沿 mesh 轴 X 分片,而逻辑轴 J 沿 mesh 轴 Y 分片。

assets/img/sharding-colored4.png

下图展示了其余几种可能:

assets/img/sharding-colored5.png

其中,$A[I_{XY}, J]$ 表示把 XY 这两个 mesh 轴视为一个更大的“扁平化”维度,并将命名轴 I 沿所有设备共同分片。多个 mesh 轴下标的顺序是有意义的,它决定了在网格上遍历分片的顺序。

assets/img/sharding-colored6.png

最后,请注意:不能让多个命名轴沿同一个 mesh 维度分片。例如,$A[I_X, J_X]$ 是没有意义、也被禁止的分片方式。一旦某个 mesh 维度已经用于分片数组的某个维度,从某种意义上说,它就“用掉了”。

思考题:A 是一个形状为 int8[128, 2048] 的数组,分片方式为 $A[I_{XY}, J]$,mesh 为 Mesh({'X': 2, 'Y': 8, 'Z': 2})(总共 32 个设备)。A 在每个设备上占多少内存?在所有设备上总共占多少内存?

点击查看答案 **答案:** 数组 **A** 沿 X 和 Y 分片、沿 Z 复制,因此每个设备上的本地形状为 `int8[128 / (2 * 8), 2048] = int8[8, 2048]`,大小是 `8 * 2048 = 16,384` 字节。由于它沿 Z 复制,所以虽然在每个 Z 平面内它在 X 和 Y 上被完全分片,但整个系统里存在 2 份完整拷贝(每个 Z 平面一份)。因此所有设备上的总大小为:原数组大小 × Z 方向复制数 = `128 * 2048 * 2 = 512 KiB`。另一种验证方式是:`32` 个设备 × `16,384` 字节/设备 = `512 KiB`。

代码里如何描述?

前面我们故意先不谈代码;现在正好可以稍微预览一下。JAX 使用一种“命名分片”语法,与上面描述的抽象语法非常接近。我们会在第 10 章更详细地讨论它;这里先快速看一个例子。你也可以在这个 Google Colab 里动手试,并 profile 一下 JAX 是如何处理不同分片方式的。下面这段代码做了三件事:

  1. 创建一个 jax.Mesh,把 8 张 TPU 映射为一个 4x2 的网格,并将两个轴命名为 'X''Y'
  2. 创建矩阵 A 和 B,其中 A 沿两个维度都做分片,B 沿输出维度做分片。
  3. 编译并执行一个简单的矩阵乘法,返回分片后的数组。
import jax
import jax.numpy as jnp

# 创建 mesh。这里我们运行在 TPU v2-8 的 4x2 slice 上,轴名为 'X' 和 'Y'。
assert len(jax.devices()) == 8
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# 一个帮助定义分片的小工具函数。PartitionSpec 就是我们的分片描述
#(把逻辑轴映射到 mesh 轴名)。
def P(*args):
  return jax.NamedSharding(mesh, jax.sharding.PartitionSpec(*args))

# A 和 B 都沿非收缩维做分片,A 还沿收缩维分片。
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))

# 我们可以直接对这些分片数组做 matmul。out_shardings 指定输出应采用
# 什么分片。JAX/XLA 会自动处理其余分片传播与通信。
y = jax.jit(lambda A, B: jnp.einsum('BD,DF->BF', A, B), out_shardings=P('X', 'Y'))(A, B)

JAX 很酷的一点在于:这些数组表现得就像它们没有分片一样。B.shape 会告诉我们全局(逻辑)形状 (2048, 8192);如果想看本地分片情况,就得去看 B.addressable_shards。你可以直接在这些数组上做操作,JAX 会尝试自动推断如何广播、重排它们以完成运算。例如,在上面的例子里,A 的本地形状是 [2, 1024]B 的本地形状是 [2048, 4096]。JAX/XLA 会根据需要自动插入通信,以完成最终的乘法。

在分片数组上做计算

如果你有一个分布在许多设备上的数组,并且想在它上面做数学运算,那么同时对“数据”和“计算”进行分片,会带来哪些额外开销?

显然,这取决于具体的运算。

本节接下来讨论的,就是如何计算分片矩阵的乘法。粗略地说,这意味着:为了把每个块真正乘出来并求和,我们要在网络中移动矩阵的某些块。不同的分片方式会引入不同的通信。 例如,$A[I_X, J] \cdot B[J, K_Y] \to C[I_X, K_Y]$ 可以在没有任何通信的情况下完成,因为收缩维度(contracting dimension,也就是实际被求和的那个维度 J)没有分片。但如果我们想让输出不分片,也就是 $A[I_X, J] \cdot B[J, K_Y] \to C[I, K]$,那就必须把 $A$ 和 $B$ 复制到每个设备,或者把 $C$ 复制到每个设备(使用 AllGather)。这两种选择的通信代价不同,因此我们需要计算其代价并选更便宜的那个。

你可以把它理解成“块矩阵乘法” 理解这一点时,回忆“块矩阵”(block matrix)的概念会很有帮助,也就是“由矩阵组成的矩阵”: $$\begin{equation} \begin{pmatrix} a_{00} & a_{01} & a_{02} & a_{03} \\ a_{10} & a_{11} & a_{12} & a_{13} \\ a_{20} & a_{21} & a_{22} & a_{23} \\ a_{30} & a_{31} & a_{32} & a_{33} \end{pmatrix} = \left( \begin{matrix} \begin{bmatrix} a_{00} & a_{01} \\ a_{10} & a_{11} \end{bmatrix} \\ \begin{bmatrix} a_{20} & a_{21} \\ a_{30} & a_{31} \end{bmatrix} \end{matrix} \begin{matrix} \begin{bmatrix} a_{02} & a_{03} \\ a_{12} & a_{13} \end{bmatrix} \\ \begin{bmatrix} a_{22} & a_{23} \\ a_{32} & a_{33} \end{bmatrix} \end{matrix} \right) = \begin{pmatrix} \mathbf{A_{00}} & \mathbf{A_{01}} \\ \mathbf{A_{10}} & \mathbf{A_{11}} \end{pmatrix} \end{equation}$$ 矩阵乘法有一个很好的性质:如果把乘数写成按块组织的形式,那么乘积也可以按标准规则写成块矩阵乘法: $$\begin{equation} \begin{pmatrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{pmatrix} \cdot \begin{pmatrix} B_{00} & B_{01} \\ B_{10} & B_{11} \end{pmatrix} = \begin{pmatrix} A_{00}B_{00} + A_{01}B_{10} & A_{00}B_{01} + A_{01}B_{11} \\ A_{10}B_{00} + A_{11}B_{10} & A_{10}B_{01} + A_{11}B_{11} \end{pmatrix} \end{equation}$$ 这意味着,实现分布式矩阵乘法,本质上就是:把这些分片块在网络中移动,在本地对块做矩阵乘法,再把结果求和。**问题只在于:该加什么通信,以及它的代价有多大。**

方便的是,我们可以把所有可能的分片方式大致归纳成 4 种情况;每种情况都有一个对应的通信规则:

  1. 情况 1: 两个输入都没有沿收缩维度分片。我们可以直接对本地 shard 做乘法,无需通信
  2. 情况 2: 其中一个输入沿收缩维度分片。我们通常会先沿收缩维度对这个输入做 AllGather
  3. 情况 3: 两个输入都沿收缩维度分片。我们可以先做本地乘法,再对结果做 AllReduce
  4. 情况 4: 两个输入的某个非收缩维度都沿同一个 mesh 轴分片。这种情况下,必须先对其中一个输入做 AllGather,否则无法继续。

你当然可以把它们当作死记硬背的规则,但理解为什么这些规则成立、以及它们的代价如何,也很有价值。下面我们依次展开。

情况 1:两个乘数都没有分片的收缩维度

引理: 分片矩阵相乘时,只要收缩维度没有分片,并且两个矩阵没有沿同一个轴做冲突的分片,那么计算就是合法的,输出也会自然继承输入的分片方式。例如,下面这个计算完全没问题:

$$\begin{equation*} \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow \mathbf{C}[I_X, K_Y] \end{equation*}$$

它完全不需要通信,并且结果张量会沿硬件维度 X 和 Y 分片。可以想想这是为什么:本质上,计算和分片是独立的,因为每个 batch 项都拥有被收缩那个轴的一部分本地块,它可以直接在本地完成乘法和归约。下面这些情形都成立,也都服从同一条规则:

$$\begin{align*} \mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I, K] \\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I_X, K]\\ \mathbf{A}[I, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I, K_Y]\\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I_X, K_Y] \end{align*}$$

因为 AB 都没有在收缩维度 J 上分片,所以只需要对输入做本地块矩阵乘法,结果就会天然满足所需的输出分片。如果两个乘数的非收缩维度沿同一个轴分片,就不再是这样了(详见后面的“非法分片”情形)。

情况 2:一个乘数有分片的收缩维度

来看这样一种情况:输入 A 沿收缩维度 J 分片,而 B 是完全复制的:

$$\mathbf{A}[I, J_X] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]$$

这时,我们不能直接把 AB 的本地块拿来相乘,因为计算需要沿 A 的整个收缩维度求和,而这个维度已经被拆到了 X 轴上。通常的做法是:先把 A 的各个 shard 做一次 AllGather,让每个设备都拿到完整副本,然后再与 B 相乘:

$$\textbf{AllGather}_X[I, J_X] \rightarrow \mathbf{A}[I, J]$$ $$\mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]$$

这样,真正的乘法就可以在每个设备上完整地执行。

**要点:** 当矩阵乘法里某个矩阵沿收缩维度分片时,我们通常先对它执行 AllGather,让收缩不再发生在分片轴上,然后再做本地 matmul。

注意,如果 B 没有沿 X 轴分片,另一种做法是:先在本地算出部分 matmul 结果,再对这些部分和做求和(或 AllReduce)。在某些情况下这会更快。见后面的题 4。

什么是 AllGather?

AllGather 是我们要讨论的第一个核心 MPI 通信原语。AllGather 会消除某个轴上的分片,并把分散在各设备上的各个 shard 在该轴上重新收集到每一个设备上。按照上面的记号,它相当于移除一个维度上的下标。例如:

$$\textbf{AllGather}_{XY}(A[I_{XY}, J]) \rightarrow A[I, J]$$

我们不一定要一次移除该维度上的所有下标,例如,$A[I_{XY}, J] \rightarrow A[I_Y, J]$ 也是一次 AllGather,只不过只沿其中一个轴做 gather。

另外,AllGather 也不只用来消除收缩维度上的分片;它也可以用于消除非收缩维度的分片。例如下面这个乘法:

$$A[I_X, J] \cdot B[J, K] \rightarrow C[I, K]$$

我们既可以先对 A 做 AllGather 以消除输入分片,也可以先执行分片 matmul,再对结果 C 做 AllGather。

AllGather 实际是怎么做的?

如果要沿单个 TPU 轴(一个环)做一维 AllGather,大致做法是:每个 TPU 把自己的 shard 沿环不断传递,直到每个设备都收齐全部 shard。在 GPU 上也可以类似实现:把一个节点内的 GPU 组织成一个环,然后按某种任意顺序在这个环上传递数据块。 下面是一个动画:

<b>图:</b> 在 8 个 TPU 或 GPU 设备上执行 AllGather 的动画。每个设备一开始持有数组的 1/8,最终每个设备都得到完整副本。
图: 在 8 个 TPU 或 GPU 设备上执行 AllGather 的动画。每个设备一开始持有数组的 1/8,最终每个设备都得到完整副本。

AllGather 可以单向进行,也可以双向进行(上图展示的是双向)。如果单向做,那么每个 TPU 要把大小为 $\text{bytes} / N$ 的块沿环传递 $N - 1$ 跳;如果双向做,则需要 $\lfloor \frac{N}{2} \rfloor$ 跳,每跳的数据量是 $2 \cdot \text{bytes} / N$。

它要花多长时间?

我们以双向 AllGather 为例来估算其耗时。设数组大小为 $V$ 字节,收缩维度上总共有 $X$ 个 shard。那么由上图可知,每一跳在两个方向上各发送 $V / \lvert X\rvert$ 字节,所以每一跳耗时:

$$T_{hop} = \frac{2 \cdot V}{\lvert X \rvert \cdot W_\text{ici}}$$

其中 $W_\text{ici}$ 是 双向 ICI 带宽。分子中的 2 来自我们使用的是双向带宽。我们向两个方向各发送 $V / X$,总计就是 $2V / X$。 为了把数据送到所有 TPU,我们总共需要走 $\lvert X\rvert / 2$ 跳严格来说,是 $\lfloor X / 2 \rfloor$。,因此总通信时间为:

$$T_{total} = \frac{2 \cdot V \cdot X}{2 \cdot X \cdot W_\text{ici}}$$ $$T_{total} = \frac{V}{W_\text{ici}}$$

请注意,这个结果和 $X$ 无关! 这很有意思,因为它说明即使 TPU 只是局部连接的,连接的“局部性”并不影响吞吐受限下的总通信时间;真正卡住我们的,是每条链路的速度。

**要点:** 当 AllGather(或 ReduceScatter、AllReduce)处于吞吐受限区间时,实际通信时间只取决于数组大小与可用带宽,而不取决于数组被分片到多少个设备上!

关于 ICI 延迟的一点说明

每一次 ICI hop 都有一个与数据量无关的固定开销,通常约为 1us。因此,当数组很小、每一跳的数据传输时间小于 1us 时,我们就会进入延迟受限(latency-bound)区间,此时计算结果就会依赖于 $X$。

点击查看完整公式 设单跳最小耗时为 $T_\text{min}$,则: $$T_{hop} = \max \left[ T_{min}, \frac{2 \cdot V}{X \cdot W_\text{ici}} \right]$$ $$T_{total} = \max \left[ \frac{T_{min} \cdot X}{2}, \frac{V}{W_\text{ici}} \right]$$ 因为我们需要走 $X / 2$ 跳。对较大的归约或 gather 来说,我们显然是带宽受限的:数据量足够大,单跳固定开销可忽略。但对小数组(例如做模型采样时)就不是这样了,此时 ICI 带宽不再关键,真正的瓶颈是延迟。换句话说,对于某种具体 TPU,例如单向 ICI 带宽为 `4.5e10` 的 TPU v5e,任何小于 `4.5e10 * 1e-6 = 45kB` 的 buffer,都将落在延迟受限区间。

下面是 TPU v5e 8x16 slice 上 AllGather 带宽的经验测量。数组沿大小为 16 的轴分片,因此形成了完整的双向环。

<b>图:</b> TPU v5e 上执行 AllGather 时的经验带宽与估计链路带宽。橙色曲线是实际 AllGather 吞吐(字节/秒),蓝色曲线是根据已知 collective 代价推算出的经验单向链路带宽。
图: TPU v5e 上执行 AllGather 时的经验带宽与估计链路带宽。橙色曲线是实际 AllGather 吞吐(字节/秒),蓝色曲线是根据已知 collective 代价推算出的经验单向链路带宽。

注意,我们不仅达到了标称峰值带宽(4.5e10)的大约 95%,而且大约在 10MB 时就能接近峰值;如果是 16 路分片,相当于每设备只需要约 625kB。(顺便一提:这比 GPU 好得多。)

如果沿多个轴做 AllGather,会发生什么?

当我们沿多个轴做 gather 时,可用于通信的 ICI 维度也随之增多。例如,AllGather_{XY}([B, D_{XY}]) 会同时利用两个硬件 mesh 轴,因此可用带宽会增加为原来的 $N_\text{axes}$ 倍。

在考虑延迟时,一般规则变为:

$$T_{total} = \max \left[ \frac{T_{min} \cdot \sum_{i} |X_i|}{2}, \frac{V}{W_\text{ici} \cdot N_\text{axes}} \right]$$

其中,$\sum_i \lvert X_i \rvert / 2$ 是 TPU mesh 上最长路径的长度。

思考题 2 [AllGather 时间]: 使用第 2 章中的参数,在 TPU v5e 上,对二维 mesh {'X': 8, 'Y': 4} 执行 AllGather_Y([E_Y, F]) → [E, F],若 E = 2048F = 8192 且数据类型是 bfloat16,需要多长时间?如果 E=256, F=256 又如何?

点击查看答案 **答案:** 先算几个基本量: 1. TPU v5e 每个轴的单向 ICI 带宽为 `4.5e10 bytes/s`。 2. 在 (a) 中,数组为 $A[E_Y, F]$,因此每个设备持有形状 `bfloat16[512, 8192]`,大小是 `512 * 8192 * 2 = 8.4MB`。整个数组大小是 `2048 * 8192 * 2 = 34MB`。 **第 (1) 问:** 套用上面的公式,因为只沿 1 个轴做 AllGather,所以有 $$T_{\text{comms}} = \text{34e6} / \text{9e10} = \text{377us}$$ 为了确认没有进入延迟受限区间,注意到轴大小为 4,最多只有 3 跳,所以延迟下界大概只有 `3us`,远小于上面的结果。不过,TPU v5e 只有当某个轴大小为 16 时才有 wraparound 连接,因此这里**实际上不能**做完整的双向 AllGather。边缘设备的数据必须经过 3 跳才能到达另一端,因此理论上更接近: $$T_{\text{comms}} = 3 * \text{8.4e6} / \text{4.5e10} = 560\mu s$$ [这里](https://imgur.com/a/RkvpRGQ)有一个来自[这个 Colab](https://colab.research.google.com/drive/15tDZMfNqm2vJjvSzw5VC9qtSwc5td-oV?usp=sharing)的真实 profile,显示大约是 $680 \mu s$,考虑到通常达不到 100% 理论带宽,这是合理的。 **第 (2) 问:** 每个 shard 的大小是 `64 * 256 * 2 = 32kB`。`32e3 / 4.5e10 = 0.7us`,因此它处于延迟受限区间。由于共有 3 跳,粗略估计就是 `3 * 1us = 3us`。[实际测量更接近 8us。](https://imgur.com/a/HZLQmYs)

**说明:** 当我们写一个二维 mesh,例如 `{'X': 16, 'Y': 4}` 时,并不要求每个轴都必须对应某个特定的**硬件**轴。这意味着它也可以描述一个 `4x4x4` 的 TPU v5p 立方体,其中在物理 $X$ 方向上抽取了两个轴。这一点会在后面讨论多轴数据并行时再次出现。

情况 3:两个乘数的收缩维度都被分片

第三种基本情况是:两个乘数都沿同一个 mesh 轴,在它们的收缩维度上做了分片:

$$\textbf{A}[I, J_X] \cdot \textbf{B}[J_X, K] \rightarrow C[I, K]$$

这时,做本地的分片块矩阵乘法至少仍然是可行的,因为两边共享相同的收缩索引集合。但每个本地乘积都只是最终结果的一部分部分和(partial sum),沿 X 轴的每个设备上会留下不同的部分和。由于这种情形太常见,我们扩展了一下记号,专门标出这种“尚未归约”的状态:

$$\textbf{A}[I, J_X] \cdot_\text{LOCAL} \textbf{B}[J_X, K] \rightarrow C[I, K] \{\ U_X \}$$

其中 $\{U_X\}$ 读作“沿 X mesh 轴未归约(unreduced)”,表示这个结果从某种意义上说还是“未完成”的,因为还差最后一次求和。$\cdot_\text{LOCAL}$ 表示我们执行了本地求和,但把最终结果保留在未归约状态。

从矩阵乘法与外积的关系来看,这一点很好理解:

$$A \cdot B = \sum_{i=1}^{P} \underbrace{A_{:,i} \otimes B_{i,:}}_{\in \mathbb{R}^{n \times m}}$$

其中 ⊗ 表示外积。因此,如果 mesh 轴 X 上的第 i 个 TPU 持有 A 的第 i 列和 B 的第 i 行,那么它就能在本地计算出

$$A_{:,i} \otimes B_{i,:} \in \mathbb{R}_{n\times m}$$

这个矩阵的每个元素,都是最终 A • B 结果中对应位置求和式的第 i 项。我们仍然需要沿 mesh 轴 X 把这些项求和,才能得到完整的 A • B。如果把 AB 写成块(即 shard)的形式,结论也是一样的:对结果的每个块做和即可。

我们可以通过沿 X 轴做一次完整的 AllReduce 来完成这一步:

$$\begin{align*} A[I, J_X] \cdot_\text{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] \{ U_X \} \\ \textbf{AllReduce}_X C[I, K] \{ U_X \} \rightarrow &\ C[I, K] \end{align*}$$

AllReduce 会消除这些部分和,使该轴上的每个设备都拿到完全求和后的值。AllReduce 是本节讨论的第二个关键通信操作;前一个是 AllGather,另外两个是 ReduceScatter 和 AllToAll。AllReduce 接收一个沿某个轴“未归约”(部分求和)的数组,沿该轴把 shard 传递并累加,最终完成求和。它的签名可以写成:

$$\textbf{AllReduce}_Y A[I_X, J] \{U_Y\} \rightarrow A[I_X, J]$$

也就是说,它只会移除后缀 $\{U_Y\}$,其余分片方式保持不变。

AllReduce 的代价是多少?

理解 AllReduce 的一个心智模型是:每个设备把自己的 shard 发给邻居,并把收到的 shard 累加起来。显然,这比 AllGather 更贵,因为这里每个“shard”的形状其实和整个数组一样大。通常来说,一次 AllReduce 的代价大约是 AllGather 的两倍。

一个理解方式是:AllReduce 可以看成两个其他原语的组合,即 ReduceScatterAllGather。与 AllReduce 一样,ReduceScatter 也会消除部分和,但它的输出会沿某个逻辑维度“散开”(scatter)成分片;随后,AllGather 再把这些分散的块重新收集起来,使逻辑轴在该物理轴上恢复未分片状态:

$$\begin{align*} \textbf{ReduceScatter}_{Y,J} : A[I_X,J] \{U_Y\} \rightarrow &\ A[I_X, J_Y] \\ \textbf{AllGather}_Y : A[I_X, J_Y] \rightarrow &\ A[I_X, J] \end{align*}$$

那什么是 ReduceScatter?正如 AllGather 会重新拼回一个分片数组(移除下标),ReduceScatter 会先对一个未归约/部分求和的数组做求和,然后把另一个逻辑轴沿同一个 mesh 轴分片。例如:$X[F]\{U_Y\} \to X[F_Y]$。它的动画和 AllGather 很像,只不过不是把收到的 shard 都保留,而是把它们加起来。因此,除了 reduction 本身的计算外,它的延迟大体和 AllGather 相同。

assets/img/reduce-scatter.gif

每一跳的通信时间,就是每个 shard 的字节数 $V / Y$ 除以带宽 $W_\text{ici}$,和 AllGather 一样。因此有:

$$T_{\text{comms per AllGather or ReduceScatter}} = \frac{V}{W_\text{ici}}$$ $$T_{\text{comms per AllReduce}} = 2 \cdot \frac{V}{W_\text{ici}}$$

其中 $W_\text{ici}$ 仍是双向带宽,前提是我们有一个完整的环来执行归约。

情况 4:两个乘数都在同一个轴上对非收缩维做了分片

一个 mesh 维度在一个张量的分片方式里最多只能出现一次。按照上面的规则推演,有时会得到违反该规则的情形,例如:

$$A[I_X, J] \cdot B[J, K_X] \rightarrow C[I_X, K_X]$$

这是非法的。原因在于:沿 X 维的某个 shard,比如第 i 个,只会持有 C 的第 (i, i) 个 shard,也就是一个对角块。系统中所有 shard 汇总起来,仍然无法恢复除对角块外的其余内容,所以这种分片方式不能成立。

解决办法是:先对某些维度做 AllGather。这里有两种选择:

$$\begin{align*} \textbf{AllGather}_X A[I_X, J] \rightarrow &\ A[I, J] \\ A[I, J] \cdot B[J, K_X] \rightarrow &\ C[I, K_X] \end{align*}$$

或者:

$$\begin{align*} \textbf{AllGather}_X B[J, K_X] \rightarrow &\ B[J, K] \\ A[I_X, J] \cdot B[J, K] \rightarrow &\ C[I_X, K] \end{align*}$$

无论选哪一种,结果的分片形状里都只会出现一次 X。具体选哪一种,要取决于后续操作更需要哪种输出分片。

更深入地理解 TPU 通信原语

前面那 4 种情况已经引入了几种用于执行分片矩阵乘法的“核心通信原语”:

  1. AllGather: 移除分片下标,把各个 shard 收集起来。
  2. ReduceScatter: 消除数组上的“未归约”后缀,沿该轴求和,并让数组沿另一个轴变成分片状态。
  3. AllReduce: 消除“未归约”后缀,并使结果在该轴上保持未分片。

还有一个核心通信原语也必须提一下,它在 Mixture of Experts(MoE)模型和其他计算中经常出现:AllToAll

最后一个通信原语:AllToAll

最后一个基本 collective,在分析分片矩阵乘法时不一定自然出现,但在实践里非常常见:AllToAll,更准确地说,是分片转置(sharded transposition)或重新分片(resharding)的特殊情形。例如:

$$\textbf{AllToAll}_{X, J} A[I_X, J] \rightarrow A[I, J_X]$$

AllToAll 通常用于在分片计算的不同区域之间重排分片布局,使原本不兼容的布局方案能够衔接起来。它在分片的 Mixture-of-Experts 模型中自然出现。你可以把 AllToAll 理解为把一个下标从某个轴挪到另一个轴上。 因为 all-to-all 不需要把每个 shard 的全部数据复制到整个环上,所以它实际上比 AllGather 更便宜(大约便宜 1/4)对于偶数大小的双向环,每个设备向右发送 $(N/2 + (N/2-1) + … + 1)$ 个块,向左发送 $((N/2-1) + … + 1)$ 个块,因此总量为 $0.5 \cdot (N / 2) \cdot (N/2 + 1) + 0.5 \cdot (N / 2) \cdot (N/2 - 1) = N^2/4$。每个块(也就是 shard 的再分块)大小为 $\text{bytes} / N^2$,所以每个设备的代价就是 $(\text{bytes} / N^2) \cdot N^2 / 4 = \text{bytes} / 4$。随着设备数增加,总带宽也增加,因此这个结果会随之扩展到所有设备。

assets/img/all-to-all.gif

如果推广到 N 维 AllToAll,那么在 A x B x C x ... mesh 上,大小为 $V$ 字节的数组,其总开销为:

$$T_\text{comms per AllToAll} = \frac{V \cdot \max(A, B, C, ...)}{4 \cdot N \cdot W_\text{ici}}$$

其中照例 $W_\text{ici}$ 是双向 ICI 带宽。对于一维 mesh,这就退化为 $V / (4 \cdot W_\text{ici})$,即 AllGather 的 1/4。对二维 mesh 而言,它的代价实际上会随着最小轴的变大而进一步下降。

顺带一提,如果你想要一个粗略的推导,可以从一维环面 $\mathbb{Z} / N\mathbb{Z}$ 开始:随机选一个源节点和目标节点,它们平均距离是 $N/4$ 跳,于是代价大致是 $(V \cdot N) / (4N)$。推广到 N 维环面时,各个轴近似独立;每个节点持有 $1/N$ 的数据,平均需要移动 $\max(A, B, C, \ldots)/4$ 跳。

再谈 ReduceScatter

ReduceScatter 比它看起来更基础,因为它其实是 AllGather 的导数(导数在反向传播意义下),反过来也成立。也就是说,如果前向过程中我们有:

$$\textbf{AllGather}_X A[I_X] \rightarrow A[I]$$

那么在反向传播里,我们就会对反向导数 A'(一般来说在各个 shard 上并不相同)做 ReduceScatter,得到分片后的 A'

$$\textbf{ReduceScatter}_X A'[I] \{ U_X \} \rightarrow A'[I_X]$$

同样地,如果前向是 $\text{ReduceScatter}_X(A[I]\{U_X\}) \to A[I_X]$,那么反向就是 $\text{AllGather}_X(A'[I_X]) \to A'[I]$。

点击查看“AllGather 与 ReduceScatter 互为导数”的细节 这背后的原因是:作为线性算子,broadcast 和 reduce 互为转置;而 AllGather 与 ReduceScatter,分别是 broadcast 和 reduce 的外积(也叫 [Kronecker product](https://en.wikipedia.org/wiki/Kronecker_product))。 具体地,设有向量 $x \in \mathbb{R}^n$,设备数为 $p \in \mathbb{N}$,并记 $u = (1, \ldots, 1) \in \mathbb{R}^p$,则可以这样定义 broadcast 与 reduce: $$ \begin{align*} \text{broadcast} &: \mathbb{R}^n \rightarrow \mathbb{R}^{p n} \\ \text{broadcast} &= u \otimes \mathbf{I}_n \\ \text{reduce} &: \mathbb{R}^{p n} \rightarrow \mathbb{R}^n \\ \text{reduce} &= u^T \otimes \mathbf{I}_n \end{align*} $$ 我们看一个简单例子,令 $n = 1$、$p = 2$。如果 $x = (7)$,那么: $$\text{broadcast}(x) = \left(\begin{pmatrix} 1 \\ 1 \end{pmatrix} \otimes \begin{pmatrix} 1 \end{pmatrix}\right) x = \begin{pmatrix} 1 \\ 1 \end{pmatrix} x = \begin{pmatrix} 7 \\ 7 \end{pmatrix} \in \mathbb{R}^{pn}$$ 这正符合我们对 broadcast 的直觉:把 $\mathbb{R}^n$ 里的向量复制到 $\mathbb{R}^{pn}$。再令 $y = (8, 9)$,则: $$\text{reduce}(y) = \left(\begin{pmatrix} 1 & 1 \end{pmatrix} \otimes \begin{pmatrix} 1\end{pmatrix}\right) y = \begin{pmatrix} 1 & 1 \end{pmatrix} \begin{pmatrix} 8 \\ 9 \end{pmatrix} = \begin{pmatrix} 17 \end{pmatrix}$$ 这也符合我们的直觉:把 $\mathbb{R}^{pn}$ 中的向量归约成 $\mathbb{R}^{n}$ 中的向量。因为对任意矩阵 $A, B$ 都有 $(A \otimes B)^T = A^T \otimes B^T$,所以 $\text{reduce} = \text{broadcast}^T$。 于是我们可以写出: $$ \begin{align*} \text{AllGather} &: \mathbb{R}^{pn} \rightarrow \mathbb{R}^{p^2 n} \\ \text{AllGather} &= \text{broadcast} \otimes \mathbf{I}_p \\ \text{ReduceScatter} &: \mathbb{R}^{p^2 n} \rightarrow \mathbb{R}^{pn} \\ \text{ReduceScatter} &= \text{reduce} \otimes \mathbf{I}_p \end{align*} $$ 这里,我们把 $\mathbb{R}^{p^2 n}$ 看成 $\mathbb{R}^{p \times pn}$,也就是每个设备对应一个 $\mathbb{R}^{pn}$ 向量。你可以试着用小例子(例如 `n = 2, p = 3`)把这些算子写成矩阵形式,会更直观。利用同样的转置性质,我们再次得到 $\text{AllGather}^T = \text{ReduceScatter}$,当然也有 $\text{ReduceScatter}^T = \text{AllGather}$。这件事会直接出现在反向传播中:如果 $y = Ax$,其中 $A$ 是 AllGather 或 ReduceScatter 这样的线性算子,那么在反向传播时,给定 $\frac{\partial L}{\partial y}$,就有 $$\frac{\partial L}{\partial x} = A^T \frac{\partial L}{\partial y}$$ 这正说明了:AllGather 的导数是 ReduceScatter,反之亦然。

把一次 AllReduce 拆成 AllGather 和 ReduceScatter 还有一个好处:我们可以把最终的 AllGather 推迟到后面的某个时刻。很多情况下,我们并不想马上为“把整个矩阵乘积重新组装为各设备上的完整复制版本”付出代价;相反,即便在两个输入的收缩维都被分片的情况下,我们仍然希望结果保持某种分片状态:

$$A[I, J_X] \cdot B[J_X, K] \rightarrow C[I, K_X]$$

这时,我们就可以用 ReduceScatter 来替代 AllReduce,然后在更晚的时候视需要再做 AllGather:

$$\begin{align*} A[I, J_X] \cdot_{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] \{ U_X \} \\ \textbf{ReduceScatter}_{X,K} C[I, K] \{ U_X \} \rightarrow &\ C[I, K_X] \end{align*}$$

注意,ReduceScatter 会引入一个新的分片维度,因此这里它自然可以选择把 IK 其中一个命名维度沿 X 做分片。一般来说,在使用 ReduceScatter 时,我们都需要决定“新增的这个分片下标到底加在哪个逻辑轴上”(尽管在实际模型上下文里,这个选择通常是被迫的)。这也是为什么我们写成 ReduceScatterX,K,用来明确指定把哪个轴分片。

如何把 matmul 的通信与计算重叠起来

正如我们在第 1 章中讨论过的,只要通信足够快,我们一般就假设它总能与某些有用计算重叠。本节提到的这些 collective,原则上都可以与矩阵乘法本身重叠,但实现起来并不简单。这里常用的一种算法叫做 collective matmul,最早由 Wang et al. 描述。下面是一个简化后的动画,展示了这种重叠如何实现:

<b>图:</b> 一个分片矩阵-向量乘法如何与其后续 AllReduce(对应上面的情况 3)重叠执行。完整 matmul 由多个矩阵-向量乘法组成。
图: 一个分片矩阵-向量乘法如何与其后续 AllReduce(对应上面的情况 3)重叠执行。完整 matmul 由多个矩阵-向量乘法组成。

简单说,我们可以在计算矩阵的某个块时,同时启动对前一个块的环形归约。在某些情况下,我们还可以沿 batch 维或矩阵输入维做 tile。我们会在第 10 章给出一个简单的 JAX 实现;Mosaic 文档里也有一个很好的 GPU 示例。很推荐你以后亲手实现一次。

我们学到了什么?

assets/img/all-collectives.png
$$T_{\text{comm per AllGather or ReduceScatter}} = \frac{\text{数据量}}{\text{带宽}} \cdot \frac{\text{Axis} - 1}{\text{Axis}} \longrightarrow \frac{\text{数据量}}{\text{双向带宽}}$$
操作 描述 记号 运行时间
AllGather 沿某个轴收集分片数组的所有 shard,移除一个下标。 $[A_X, B] \to [A, B]$ bytes / (双向 ICI 带宽 * num_axes)
ReduceScatter 沿某个轴把部分和求和,并沿另一个轴分片(添加一个下标)。 $[A, B] \{U_X\} \to [A_X, B]$ 与 AllGather 相同
AllReduce 沿某个轴把部分和求和,移除一个 {U_X}。本质上是 ReduceScatter + AllGather。 $[A_X, B]\{U_Y\} \to [A_X, B]$ 2 * AllGather
AllToAll 在同一个轴上,一边 gather(复制)某个维度,一边把另一个维度分片。 $[A, B_X] \to [A_X, B]$ 双向环中约为 AllGather / 4

练习题

下面这些练习都基于本节内容。我们暂时不会给出所有答案,但会尽量补充更多。

问题 1 [复制型分片]: 一个数组按 $A[I_X, J, K, \ldots]$ 分片(也就是说只沿 $X$ 分片),mesh 为 Mesh({'X': 4, 'Y': 8, 'Z': 2})。那么,A 在所有芯片上占用的总字节数,与数组单份大小之比是多少?

点击查看答案 该数组只沿 X 分片,而 X 的大小为 4,因此每个 shard 的大小实际上是 $[I / 4, J, K, \ldots] = \text{sizeof}(A) / 4$。因为它沿 Y 和 Z 都是复制的,所以总大小是 $Y \cdot Z \cdot \text{sizeof}(A)$,于是总大小与单份大小之比就是 $Y \cdot Z = 16$。

问题 2 [AllGather 延迟]: 在 TPU v4p 4x4x4 slice 上,若 mesh 为 Mesh({'X': 4, 'Y': 4, 'Z': 4}),则 $\text{AllGather}_X([B_X, D_Y])$ 在 B=1024D=4096、bfloat16 下大约需要多久?如果改成 $\text{AllGather}_{XY}([B_X, D_Y])$ 呢?如果是 $\text{AllReduce}_Z([B_X, D_Y] \{U_Z \})$ 呢?

点击查看答案 因为这是完整的 `4x4x4` 立方体,所以每个轴都有 wraparound 链路,可用双向带宽为 `9e10`。 1. 因为只沿一个轴 gather,而另一个轴仍然分片,所以本质上是在一个 Y 固定的 shard 上,对 `2BD / Y` 字节做 AllGather。换句话说,沿 X 的 AllGather 看起来像是“对原数组的 `1/Y` 大小做未分片 AllGather”。因此耗时是: $$2BD / (\text{9e10} \cdot Y) = 2 \cdot 1024 \cdot 4096 / (\text{9e10} \cdot 4) = 23 \mu s$$ 2. 这里带宽翻倍了,但我们 gather 的是整个数组,所以: $$T = 2BD / (2 * W) = 2 * 1024 * 4096 / (2 * 9e10) = 46us$$ 这远高于 `4us` 的延迟下界(每 hop 约 `1us`),因此不是延迟受限。 3. AllReduce 的代价是 AllGather 的两倍。每个 shard 大小是 $2BD / (X * Y)$,因此成本约为: $$4BD / (X * Y * W) \approx 4 * 1024 * 4096 / (16 * 9e10) = 11.6us$$

问题 3 [延迟受限的 AllGather]: 假设我们要执行 $\text{AllGather}_X([B_X])$,但 B 很小(例如 128)。在 TPU v4p 4x4x4 slice 上、mesh 为 Mesh({'X': 4, 'Y': 4, 'Z': 4})、数据类型为 bfloat16 时,它应当花多长时间?提示:你大概率是延迟受限。

点击查看答案 这个 bfloat16 数组总共只有 256 字节,每个设备上只有 64 字节。由于 TPU v4p 上该轴大小为 4,我们有 wraparound 链路,可以双向发送。以 `4.5e10` 的单向带宽计算,每 hop 的数据传输时间大约是 `64 / 4.5e10 ~ 0`,所以显然是延迟受限。数一下 hop 数:完整 gather 只需 2 跳,因此粗略估计大约是 `2us`。

问题 4 [matmul 策略]: 为了执行 $X[B, D] \cdot_D Y[D_X, F] \to Z[B, F]$,本节建议的做法是先执行 $\text{AllGather}_X(Y[D_X, F])$,再用完整复制的矩阵做乘法(情况 2,称作 策略 1)。另一种做法是像 $X[B, D_X] \cdot_D Y[D_X, F] \to Z[B, F] \{U_X\}$ 这样先算本地 shard(情况 3,称作 策略 2),然后再执行 $\text{AllReduce}_X(Z[B, F] \{ U_X\})$。两种做法分别需要多少 FLOPs 和通信?哪一种更好?为什么?

点击查看答案 先看基线方案(*策略 1*)。如上所述,AllGather 的代价是 $2DF / W_\text{ici}$。拿到完整复制的数组后,总计算时间为 $2BDF / C$(其中 $C$ 是加速器 FLOPs/s,因为每张 TPU 做的是相同的 FLOPs)。于是: $$T_\text{total (Strategy 1)} = \max\left(\frac{2BDF}{C}, \frac{2DF}{W_\text{ici}}\right)$$ 相比之下,新的策略(*策略 2*)对 $2BF$ 字节做一次 AllReduce,代价为 $4BF / W_\text{ici}$,但它会少做 $1/X$ 的 FLOPs(因为计算被分片了)。也就是说它执行 $2 \cdot B \cdot D \cdot F / X$ FLOPs,而 AllReduce 会对 `bfloat16` 的结果通信 $$2 \cdot 2 \cdot B \cdot F$$ 字节。因此总时间近似为: $$T_\text{total} = \max\left(\frac{2BDF}{X \cdot C}, \frac{4BF}{W_\text{ici}}\right)$$ 问题在于:**哪一个更大?** 如果策略 2 是计算受限,就必须满足 $D / (X \cdot C) > 2 / W_\text{ici}$,也就是 $D / 2X > C / W_\text{ici} \approx 2550$,即 $X < D / (2 * 2550)$。若典型地取 $D \approx 8k$,则意味着大约 $X < 2$,这不太常见,所以策略 2 基本总是通信受限。对于基线方案(策略 1),当 $B < C / W_\text{ici} = 2550$ 时它是通信受限,这种情况很常见,但并非总是如此。 因此,当 $B < 2550$ 时,两种策略都通信受限,这时有: $$T_\text{comms for Strategy 2} < T_\text{comms for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2DF}{W_\text{ici}}$$ 也就是当 $D > 2B$ 时成立;而这又意味着 $2B < 5100$。这在实践中经常成立,所以当 batch 较小时,策略 2 有时会更好。 当 batch 较大($B > 2550$)时,我们比较的是: $$T_\text{comms for Strategy 2} < T_\text{math for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2BDF}{C}$$ 这等价于 $2 / W_\text{ici} < D / C$,也就是 $D > 2 * 2550 = 5100$。对大模型来说这通常成立,所以除非 $D$ 很小,否则这种替代策略通常更优。 **那为什么我们不总是这么做?** 实际上有时会这样做,但“某个输入沿收缩维度分片,而另一个输入却不沿同一轴分片”的情况本来就不太常见。比如做 FSDP(见[第 5 章](../training))时,参数沿数据维分片,而激活通常也会沿数据维分片,所以这种情况并不多见。

问题 5 [最小延迟]: 假设我要在 TPU v4p 4x4x4 上,以尽可能低的延迟执行矩阵乘法 $A[I, J] \cdot_J B[J, K] \to C[I, K]$。输入可以任意分片,但结果必须是完全复制的。输入应该如何分片?总 FLOPs 和通信时间是多少?

点击查看(部分)答案 这里不提供完整答案,但先列出最可能的四种方案: 1. $A[I_{XYZ}, J] \cdot B[J, K]$,最后做一次 AG 2. $A[I, J] \cdot B[J, K_{XYZ}]$,最后做一次 AG 3. $A[I, J_{XYZ}] \cdot B[J_{XYZ}, K]$,最后做一次 AR 4. $A[I, J] \cdot B[J, K]$(完全复制) 我们当然也可以考虑把不同逻辑轴沿不同 mesh 轴分片,但它大概率不会改变最终成本。对于除 (4) 外的方案,每张 TPU 上的总 FLOPs 都一样,只是通信不同。我们只需分别计算通信代价并选最小者即可。结论先说:**(1) 和 (2) 一样好。**

问题 6: 假设我们想在 TPU v5e 4x4 上执行 $A[I_X, J_Y] \cdot_J B[J_Y, K] \to C[I_X, K]$。需要做什么通信?通信与计算分别花多长时间?

问题 7: 一个典型的 Transformer block 有两个矩阵:$W_\text{in}[D, F]$ 和 $W_\text{out}[F, D]$,其中 $F \gg D$。设 batch size 为 $B$,则整个 block 计算为:

$$In[B, D] \cdot W_\text{in}[D, F] \cdot W_\text{out}[F, D]$$

令 $D=8192$、$F=32768$、$B=128$,并假设数据类型都是 bfloat16。假设我们运行在 TPU v5e 2x2 slice 上,但假装每张 TPU 只有 300MB 空闲内存。为了在不超过内存限制的前提下使总时间最小,In、$W_\text{in}$、$W_\text{out}$ 和 Out 应该如何分片?通信和 FLOPs 分别是多少?提示:最终输出不需要完全复制,但应该和输入采用同样的分片方式,以便“层”可以重复堆叠。

点击查看(部分)答案 先看内存。两个大矩阵每个都占 `2 * 8192 * 32768 = 536MB`。激活 `In` 的大小是 `2 * 128 * 8192 = 2MB`,可以忽略不计。因为每个设备只有 300MB 空闲内存,所以很明显必须对 matmul 做分片。 1. $In[B_X, D] * W_\text{in}[D_{XY}, F] * W_\text{out}[F, D_{XY}] \rightarrow Out[B_X, D]$(通常称为 FSDP) 2. $In[B, D_{XY}] * W_\text{in}[D, F_{XY}] * W_\text{out}[F_{XY}, D] \rightarrow Out[B, D_{XY}]$(这就是张量并行) 第一种方案很差,因为我们需要先对大权重或激活做 AllGather。第二种方案则需要一开始做一次 AllGather,并在结尾做一次 ReduceScatter(比 AllReduce 便宜)。其余计算留给你自己完成。

问题 8 [挑战题]: 以上面的简短代码片段为模板,分配一个分片数组,然后使用 pmapshard_map 对 4 个主要通信原语(AllGather、AllReduce、ReduceScatter、AllToAll)逐一做 benchmark。你会用到 jax.lax.all_gatherjax.lax.psumjax.lax.psum_scatterjax.lax.all_to_all。你是否真正理解了这些函数的语义?它们分别要花多久?

问题 9 [分片 matmul 的另一种策略?]: 在前面的情况 2中,我们说:当 matmul 的两个输入里只有一个沿其收缩维分片时,应该先对这个分片矩阵做 AllGather,再在本地完成收缩。你可能会想到另一种策略:像“两个输入都沿收缩维分片”那样,先做分片 matmul,再对结果做 AllReduce。也就是说,用下面两步实现 $A[I, J_X] *_J B[J, K] \to C[I, K]$:

  1. $C[I, K] \{ U_X \} = A[I, J_X] \cdot B[J_X, K]$
  2. $C[I, K] = \text{AllReduce}(C[I, K] \{ U_X\})$

请回答:

  1. 用显式索引写出这个算法在矩阵 $A[N, M]$ 和 $B[M, K]$ 上的具体形式,说明每个设备做了什么计算。假设 A 在 ND 个设备上按 $A[I, J_X]$ 分片,且希望输出在所有设备上都复制。
  2. 如果你不要求最终结果在每个设备上都有完整复制,而是允许它沿 N 或 K 其中一个维度分片,那么上面的算法应如何修改?
  3. 只比较第 2 问中的通信成本,它与“先 AllGather A 再做 matmul”的通信成本相比如何?
点击查看答案 1. 首先计算外积,结果记为 $$O[N, K] : o_{kj} = \sum_i a_{ki} b_{ij}$$。注意,这里重复指标不是最后真正被收缩掉的那个,因为这里做的是外积。求和范围只覆盖当前设备上持有的那一部分 i 值。例如,如果收缩轴大小是 16、设备数是 4,那么在设备 0 上,i 的范围就是 `{0, 1, 2, 3}`;在设备 1 上是 `{4, 5, 6, 7}`;设备 2 上是 `{8, 9, 10, 11}`;设备 3 上是 `{12, 13, 14, 15}`。然后,对各设备上的 $O[N, K]$ 部分和执行 AllReduce,得到完整的 $O[N, K]$。 2. 如果不要求结果在每个设备上复制,那么第 2 步不必做 AllReduce,改成更便宜的 ReduceScatter 即可,而且可以沿任一维度分片:$[N, K] \{ U_X \} \to [N_X, K]$ 或 $[N, K] \{ U_X \} \to [N, K_X]$。 3. 正如正文所说,在吞吐受限区间内,AllGather 与 ReduceScatter 的代价相同,都只由所处理的“完整矩阵大小”决定。所以,“先 gather 再 matmul”的算法通信量尺度是 $NM$(因为我们 gather 的是 $A$);而“先 matmul 再 reduce-scatter”的算法通信量尺度是 $NK$(因为我们 reduce-scatter 的是输出 $O$)。因此两者通信成本之比为 `M/K`。

问题 10:AllToAll 的趣味题:在上面的表格中,我们提到:在吞吐受限区间内,AllToAll 的时间大约是 AllGather 或 ReduceScatter 的 1/4。这个题目会帮助你理解这个因子 4 从哪里来,也会看到:如果 ICI 链路只有单向而不是双向,这个因子会如何变化。

  1. 先看单向链路。假设有 D 个设备组成一个环拓扑,我们想对一个 N x N 矩阵 $A[I_X, J]$ 执行 AllGather 或 ReduceScatter(为简单起见,假设 D 整除 N)。描述这两个 collective 的通信过程,并计算整个算法执行期间一条 ICI 链路上总共传输了多少标量(浮点或整数)。
  2. 接着看 AllToAll,仍然是单向 ICI。此时算法与 AllGather 有什么不同?计算在该算法中,单条 ICI 链路上总共传输了多少标量。
  3. 你应该能发现,第 (a) 和第 (b) 问的结果之比是一个漂亮的数字。请用直观的话解释这个因子来自哪里。
  4. 现在引入双向通信。它会怎样影响 AllGather 所需的总时间?
  5. 引入双向通信后,AllToAll 所需总时间又会怎样变化?
  6. 最后,请直接解释:为什么在双向环里,AllGather 与 AllToAll 的时间比是 4?
点击查看答案 **(1) 解:** 过程很简单:算法每一步中,每个设备都把一个“条带状”的 shard 发送给最近邻,这个条带大小是 $$\frac{N}{D} \times N$$。这件事总共发生 $$D-1$$ 次,因为每个 shard 需要发给除了起始设备外的所有设备。因此,每个设备总共发送了 $$\frac{N^2(D-1)}{D}$$ 个标量,也就是单条 ICI 链路上流过的总量。 **答案:** $$N^2 (1-\frac{1}{D})$$,当 $$D \gg 1$$ 时可近似写作 $$N^2$$。 **(2) 解:** AllToAll 与 AllGather 在通信层面的关键区别在于:某个设备上的完整 shard 并不需要被发给所有其他设备。设设备 0 上的 shard 是 $$[A, B, C, D]$$(为了便于说明,假设环上只有 4 个设备,且 A、B、C、D 都是矩阵块)。其中块 $$A$$ 不需要发送;块 $$B$$ 最终需要到设备 1;块 $$C$$ 到设备 2;块 $$D$$ 到设备 3。因此第一步里,设备 0 把 $$B, C, D$$ 发给设备 1;第二步里,设备 1 再把 $$C, D$$ 发给设备 2;第三步里,设备 2 把 $$D$$ 发给设备 3。传输参数总数就是 $$\text{block size} * (3 + 2 + 1)$$。一般情形下,每个小块的大小是 $$\frac{N^2}{D^2}$$,而 $$(3 + 2 + 1)$$ 会变成 $$((D-1) + (D-2) + \ldots + 1) = \frac{D(D-1)}{2}$$。因此,单条 ICI 链路上的总传输量为: $$\frac{N^2(D-1)}{D \times 2}$$ **答案:** $$\frac{N^2}{2}(1-\frac{1}{D})$$,当 $$D \gg 1$$ 时可近似写作 $$\frac{N^2}{2}$$。 **(3) 解:** 这个因子就是 $$\frac{1}{2}$$。也就是说,在单向环拓扑中,AllToAll 的成本只有 AllGather/ReduceScatter 的一半。回顾推导可知,本质原因在于:AllGather 中,我们每一步都发送大小相同的整条“小条带”,于是求和类似于 $$\text{block size} * (D + D + D + \ldots + D)$$;而 AllToAll 中,发送的数据块会逐步变少,所以求和变成 $$\text{block size} * (D + (D-1) + (D-2) + \ldots + 1)$$。这个 2 倍差距最终来自著名的公式 $$1 + 2 + \ldots + n = n(n+1)/2$$。 **(4) 解:** 引入双向通信后,任意一条链路需要承载的总标量数会再减半,因为在双向环中,每条“分片条带”都可以同时向两个方向发送。 **(5) 解:** 对 AllToAll 来说,双向通信带来的是 4 倍加速。最直观的理解方式是:看某个分片条带内部各个 `N^2/D^2` 大小的小块的命运。单向时,它们分别需要走 `D-1`、`D-2`、……、`1` 跳;双向时,我们把这些块分为向左和向右移动两类,每个块最多只需走 `floor(D/2)` 跳。于是总和从单向时的大约 $$D^2/2$$ 下降到双向时的大约 $$D^2/8$$,也就是 4 倍改善。 **(6) 解:** 在单向环里,AllToAll 已经比 AllGather 快 2 倍,因为它不需要把整条 strip 发给每一个设备。引入双向通信后,AllToAll 再额外获得 4 倍收益,而 AllGather 只获得 2 倍收益。两者叠加起来,就得到了最终的 4 倍时间比。

第 3 章到这里就结束了!第 4 章(Transformer 数学)请点[这里](../transformers)!

第 4 章

第 4 部分:你需要了解的所有 Transformer 数学知识

这是《如何扩展你的模型》的第 4 部分(第 3 部分:分片 | 第 5 部分:训练

在这里,我们将快速回顾 Transformer 架构,特别是如何计算 FLOPs、字节数以及其他感兴趣的量。

目录


点积计数 (COUNTING DOTS)

让我们从以下形状的向量 $x$, $y$ 和矩阵 $A$, $B$ 开始:

\[\def \red#1{\textcolor{red}{#1}} \def \green#1{\textcolor{green}{#1}} \def \blue#1{\textcolor{blue}{#1}} \def \purple#1{\textcolor{purple}{#1}} \def \orange#1{\textcolor{orange}{#1}} \def \gray#1{\textcolor{gray}{#1}} \begin{array}{cc} \textrm{数组} & \textrm{形状} \\ \hline x & \textrm{[P]} \\ y & \textrm{[P]} \\ A & \textrm{[N P]} \\ B & \textrm{[P M]} \\ \hline \end{array}\]

收缩维度是在操作过程中求和的轴(它们出现在两个输入中但在输出中消失),如矩阵乘法中的内维。批处理维度是出现在两个输入中并原样传递到输出的共享轴;它们索引独立的子问题,不参与 FLOPs 计数的乘法。在 einsum 术语中:出现在两个输入和输出上的标签是批处理;出现在两个输入但不在输出上的标签是收缩。

\[\begin{array}{ccc} \textrm{操作} & \textrm{FLOPs} & \textrm{数据} \\ \hline x \cdot y & 2P & 2P \\ A x & 2NP & NP + P \\ AB & 2NPM & NP + PM \\ [c_0,...,c_N] \cdot [d_0,...,d_N] & 2 \prod c_i \times \prod_{\substack{d_j \notin \blue{BATCH} \\ d_j \notin \red{CONTRACT}}} d_j & \prod c_i + \prod d_j \\ \hline \end{array}\]

请注意,对于矩阵-矩阵乘法,计算量呈立方级 $O(N^3)$ 增长,而数据传输仅呈平方级 $O(N^2)$ 增长——这意味着随着我们扩大矩阵乘法规模,更容易达到计算饱和极限。这非常罕见,并在很大程度上解释了为什么我们使用以矩阵乘法为主的架构——它们适合扩展!

前向和反向 FLOPs

在训练期间,我们并不特别关心给定矩阵乘法的结果;我们真正关心的是它的导数。这意味着我们在反向传播期间执行的 FLOPs 显著更多。

如果我们想象 $B$ 只是大型网络中的一个矩阵,$A$ 是我们的输入激活值,且 $C = A B$,则损失 $L$ 对 $B$ 的导数由链式法则给出:

\[\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial B} = A^T \left(\frac{\partial L}{\partial C}\right)\]

这需要 $2NPM$ FLOPs 来计算(因为它在 $N$ 维度上收缩)。同样,损失对 $A$ 的导数是:

\[\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial A} = \left(\frac{\partial L}{\partial C}\right) B^T\]

由于 $dL/dC$ 是大小为 $[N, M]$ 的矩阵,这又是 $2NPM$ FLOPs。虽然这个量不是对参数的导数,但它用于计算网络前几层的导数(例如,就像上面的 $dL/dC$ 用于计算 $dL/dB$ 一样)。

将这些相加,我们看到在训练期间,我们总共有 $6NPM$ FLOPs,而推理期间为 $2NPM$:前向传递 $2NPM$,反向传递 $4NPM$。由于 $PM$ 是矩阵中的参数数量,这是著名的 Transformer 训练期间 FLOPs 近似公式 $6 \times \text{参数数量} \times \text{Token 数量}$ 的最简单形式:每个 Token 需要 $6 \times \text{参数数量}$ FLOPs。我们将在下面展示更准确的推导。


TRANSFORMER 账目 (TRANSFORMER ACCOUNTING)

Transformer 是未来。好吧,至少它们是现在。也许几年前,它们只是众多架构之一。但今天,了解该架构的几乎每个细节都是值得的。我们不会重新介绍该架构,但这篇博客原始 Transformer 论文可能会是非常有用的参考。

Transformer Decoder 架构图 图:标准 Transformer 的一层,从上到下流动。我们使用单字母约定来描述 Transformer 中数组的形状 and 布局,再次以红色显示收缩维度,以蓝色显示批处理维度。在给定操作中,输入形状在左上角给出,参数形状在右上角给出,结果形状在下方。

注 [门控 einsum]: 上图使用了“门控 einsums”,我们将上投影矩阵拆分为两个矩阵(上图中的 $W_{In1}$ 和 $W_{In2}$),其输出进行逐元素相乘作为一种“门控函数”。并非所有 LLM 都使用此功能,因此你有时会看到单个 $W_{In}$ 矩阵,MLP 总参数计数为 $2DF$ 而不是 $3DF$。通常在这种情况下,$D$ 和 $F$ 会按比例放大,以保持参数计数与 3 矩阵情况相同。话虽如此,LLaMA、DeepSeek 和许多其他模型都使用了某种形式的门控 einsum。

注 2 [MHA 注意力]: 对于自注意力,$T$ 和 $S$ 相同,但对于交叉注意力,它们可能不同。对于传统的并行多头注意力 (MHA),$N$ 和 $K$ 相同;而对于多查询注意力 (MQA),$K=1$;对于分组查询注意力 (GQA/GMQA),$K$ 仅需能整除 $N$。

注 3 [Pre-norm vs. Post-norm]: 上图显示的是所谓的“Post-norm” Transformer,其中层归一化发生在残差连接之后,即 norm(x + attn(x))。这与原始 Transformer 论文一致,但当今大多数现代 Transformer 使用“Pre-norm”架构,其中归一化发生在残差连接之前,通常为 x + attn(norm(x))。像 LLaMA-3 这样的模型现在就使用这种架构。


全局 FLOPs 和参数计算 (GLOBAL FLOPS AND PARAMS CALCULATION)

为了方便,我们将计算每层 FLOPs,以避免在各处添加因子 $L$。

MLPs

Transformer 的 MLP 通常由 2 个逐元素组合的输入矩阵乘法和单个输出矩阵乘法组成:

\[\begin{array}{ccc} \textrm{操作} & \textrm{训练 FLOPs} & \textrm{参数} \\ \hline \\ A[B,T,\red{D}] \cdot W_{in1}[\red{D}, F] & 6BTDF & DF \\[10pt] A[B,T,\red{D}] \cdot W_{in2}[\red{D}, F] & 6BTDF & DF \\[10pt] \sigma\left(A_{in1}\right)[B,T, F] * A_{in2}[B,T, F] & \gray{O(BTF)} \\[10pt] A[B,T,\red{F}] \cdot W_{out}[\red{F}, D] & 6BTDF & DF \\[10pt] \hline \\ & \approx 18BTDF & 3DF \end{array}\]

注意力 (Attention)

对于具有不同 $Q$ 和 $KV$ 头数的通用分组查询注意力情况,假设 $Q, K, V$ 投影的 head 维度 $H$ 相等,并估算 $QKVO$ 矩阵乘法的成本:

\[\begin{array}{ccc} \textrm{操作} & \textrm{训练 FLOPs} & \textrm{参数} \\ \hline \\ A[B,T,\red{D}] \cdot W_{Q}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{D}] \cdot W_{K}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{D}] \cdot W_{V}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{N}, \red{H}] \cdot W_{O}[\red{N}, \red{H}, D] & 6BTDNH & DNH \\[10pt] \hline \\ & 12BTD(N+K)H & 2D(N+K)H \end{array}\]

点积注意力操作更为微妙,实际上是在 $B, K$ 维度上批处理的 $TH \cdot HS$ 矩阵乘法、一个 softmax 以及同样在 $B, K$ 维度上批处理的 $TS \cdot SH$ 矩阵乘法。我们用蓝色突出显示批处理维度:

\[\begin{array}{cc} \textrm{操作} & \textrm{训练 FLOPs} \\ \hline \\[3pt] Q[\blue{B}, T, \blue{K}, G, \red{H}] \cdot K[\blue{B}, S, \blue{K}, \red{H}] & 6BTSKGH = 6BTSNH \\[3pt] \textrm{softmax}_S \;\; L[B, T, S, K, G] & \gray{O(BTSKG) = O(BTSN)} \\[3pt] S[\blue{B}, T, \red{S}, \blue{K}, G] \cdot V[\blue{B}, \red{S}, \blue{K}, H] & 6BTSKGH = 6BTSNH \\[3pt] \hline \\ & \approx 12BTSNH = 12BT^2NH \\ \end{array}\]

注 [因果掩码]: 大多数最近的 Transformer 使用因果掩码,而不是全双向注意力。在这种情况下,点积操作的有效 FLOPs 减少了 1/2。为了在实践中实现这种减少,我们需要使用注意力算子 (kernel),而不是简单的 einsum。

其他操作

Transformer 中还有其他几种操作。层归一化相对便宜,在一阶成本估算中可以忽略。还有最终巨大的(虽然不是每层都有)反嵌入 (unembedding) 矩阵乘法。

\[\begin{array}{ccc} \textsf{操作} & \textsf{训练 FLOPs} & \textsf{参数} \\ \hline \\ \textrm{layernorm}_D \;\; A[B,T,\red{D}] & \gray{O\left(BTD\right)} & \gray{D} \\[10pt] A[B,T,\red{D}] \cdot W_{unembed}[\red{D}, V] & 6BTDV & DV \\ \end{array}\]

TRANSFORMER FLOPs 的通用经验法则

如果我们忽略短上下文训练的点积注意力成本,那么所有层的总 FLOPs 为:

\[\begin{align*} (18BTDF + 12BTD(N+K)H)L = 6 \times BT \times (3DF + 2D(N+K)H)L \\ = 6 \times \text{Token 数量} \times \text{参数数量} \end{align*}\]

这导致了一个著名的估算稠密 Transformer FLOP 计数的经验法则(忽略注意力 FLOPs)。(反嵌入是另一个简单的矩阵乘法,具有 $6BTDV$ FLOPs 和 $DV$ 参数,也遵循相同的经验法则。)

注意力成本随上下文长度变化的比例

如果我们确实考虑上述点积注意力,并假设 $F=4D$,$D=NH$(典型情况)且 $N=K$:

\[\small{\frac{\textrm{注意力 FLOPs}}{\textrm{矩阵乘法 FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{4 \times 18 BTD^2 + 24 BTD^2} = \frac{12BT^2D}{96 BTD^2} = \frac{T}{8D}}\]

因此,结论是点积注意力 FLOPs 仅在训练期间 $T > 8D$ 时才占据主导地位。对于 $D \approx 8k$,这将是 $\approx 64K$ 个 Token。这很有道理,因为这意味着随着 MLP 尺寸的增加,注意力 FLOPs 变得不那么关键。对于大型模型,注意力的二次方成本实际上并不是长上下文训练的巨大障碍。然而,对于较小的模型,例如 Gemma-27B,$D=4608$,这意味着注意力在 32k 序列长度左右开始占据主导地位。Flash Attention 也有助于减轻长上下文的成本,我们将在附录 A 中简要讨论。


杂项数学 (MISCELLANEOUS MATH)

稀疏性和混合专家模型 (MoE)

如果不简要讨论混合专家 (MoE) 模型,那将是我们的失职。MoE 模型将标准 Transformer 中的单个稠密 MLP 块替换为一组可以动态路由的独立 MLP。初步近似下,一个 MoE 就像一个正常的稠密模型,每层有 $E$ 个 MLP 块,而不是一个。每个 Token 激活其中的 $k$ 个专家,通常 $k \ll E$。比率 $E / k$ 称为稀疏度,通常在 8 到 64 之间(例如 DeepSeek v3 实际上 $k=8, E=256$)。这使参数计数增加了 $O(E)$,同时与稠密版本相比,将每个 Token 的总激活参数数量乘以 $k$。

MoE 层示例 图:一个具有 $n$ 个专家的 MoE 层示例。门控专家将每个 Token 路由到其中的 $k$ 个,并将这 $k$ 个 MLP 的输出相加。我们的参数计数是每个专家大小的 $n$ 倍,但每个 Token 仅使用 $k$ 个。来源

与稠密模型相比,MoE 引入了新的通信,主要是两个 AllToAll(一个在 MoE 块之前,一个在之后),用于将 Token 路由到正确的专家并将其带回其主设备。从技术上讲,这仅在我们沿专家所在的同一轴进行数据或序列分片时才会发生。然而,正如我们在上一节中看到的,每个 AllToAll 的成本仅为沿单轴的同类 AllGather 的 1/4(对于双向环)。

梯度检查点 (Gradient checkpointing)

反向传播作为一种算法,是用计算换取内存。反向传递不需要 $O(n_{layers}^2)$ FLOPs,而是需要 $O(n_{layers})$ 内存,保存前向传递期间生成的所有中间激活值。虽然这比二次方计算好,但内存开销极其昂贵:一个具有 $B \times T=4M$(每批总共 4M 个 Token)、$L=64$ 且 $D=8192$ 的模型,如果要避免所有不必要的反向传递计算,则必须以 bfloat16 格式保存大约 $2 \times 20 \times B \times T \times D \times L = 84TB$ 的激活值。20 来自于(粗略地)计算上述 Transformer 图中的每个中间节点,因为例如:

\[f(x) = \exp(g(x))\] \[\frac{df}{dx} = \exp(g(x)) \cdot \frac{dg}{dx}\]

因此为了避免重新计算,我们需要从前向传递中保存 $g(x)$ 和 $\exp(g(x))$。为了避免保存这么多内存,我们可以选择仅保存中间激活值的一部分。以下是我们使用的一些策略:

这绝不是详尽无遗的。使用 JAX 时,这些通常由 jax.remat/jax.checkpoint 控制。

键值 (KV) 缓存

正如我们将在第 7 节中看到的,LLM 推理有两个关键部分:预填充 (prefill) 和生成 (generation)。

每个 KV 缓存实际上是一个大小为 $[2, S, L, K, H]$ 的数组,其中 2 代表键和值。这相当大!int8 格式的键值缓存总大小为 $2SLKH$。对于一个具有 8k 上下文长度、64 层且 $KH = NH = D = 8192$ 的中等规模模型,这是 $2 \times 8192 \times 64 \times 8192 = 8\text{GiB}$。你可以明白为什么我们想要使用 $K \ll N$ 的 GQA/GMQA。


本节的要点是什么?

组件 每层参数 每层训练 FLOPs
MLP 3DF 18BTDF
注意力 4DNH 24BTDNH + 12BT²NH
其他 D BTD
词表 DV (总计,非每层) 12BTDV

一些练习题

问题 1: 一个 $D=4096, F=4 \times D, V=32,000, L=64$ 的模型有多少参数?其中注意力参数占多大比例?我们的每个 Token 的 KV 缓存有多大?你可以假设 $N\cdot H=D$ 且使用 int8 KV 的多头注意力。

答案: 1. 总参数大约为 $L \times (3DF + 4DNH + D) + 2DV$。对于给定的数字,这是 $64 \times (3 \times 4000 \times 16000 + 4 \times 4000 \times 4000 + 4000) + 2 \times 4000 \times 32000 = 16B$ 参数。 2. 注意力参数与总参数的比率通常为 $4DNH / (4DNH + 3DF) = 4D^2 / (4D^2 + 12D^2) = 1/4$。这告诉我们大约 1/4 的参数用于注意力。 3. 每个 Token 的 KV 缓存为 $2 \times L \times N \times H = 2 \times 64 \times 4096$(int8 格式),即 512kB / token

问题 2:{‘X': 4, ‘Y': 8, ‘Z': 4} 上执行 $A[B_X, D_Y] \times_D W[D_Y, F]$ 总共需要多少 FLOPs?每个 TPU 执行多少 FLOPs?

答案: 该操作的总体“理论” FLOPs 为 $2 \times B \times D \times F$。然而,由于计算未在 Z 维度上分片,我们实际上执行了 Z 倍的额外 FLOPs,这意味着总共 $2 \times B \times D \times F \times Z$ FLOPs。由于计算在其他维度上分片,每个设备的总量大约为 $2 \times B \times D \times F / (X \times Y)$。

问题 3: 执行 $A[I,J,K,L] \times B[I,J,M,N,O] \rightarrow C[K,L,M,N,O]$ 涉及多少 FLOPs?

答案: 遵循上述规则,我们有 I 和 J 作为收缩维度,K, L, M, N 和 O 作为非收缩维度. 我们没有“批处理维度”,所以这只是 $2 \times I \times J \times K \times L \times M \times N \times O$,即所有轴的乘积。如果我们有一个共享轴,它只会被计算一次。

问题 4: 自注意力的算术强度 (arithmetic intensity) 是多少(忽略 Q/K/V/O 投影)?给出作为 Q 和 KV 长度 T 和 S 的函数的答案。在什么上下文长度下,注意力是受计算限制 (FLOPs-bound) 的?鉴于我们 TPU 的 HBM 带宽,请绘制随着上下文长度增长,注意力相对于 FFW 块的有效相对成本。

答案: 自注意力需要加载 $Q, K, V$ 激活值,然后计算 $\text{softmax}(Q \cdot K) \cdot V$,最后将结果写回 HBM。这将使用 Flash Attention 完成,因此数学上会有一些注意事项,但基本上在 bf16 中,自注意力执行: \[\text{Q[B,T,N,H]} \rightarrow_\text{reshape} \text{Q[B, T, K, G, H]} \cdot \text{K[B, S, K, H]} \rightarrow \text{O[B, T, S, K, G]}\] \[U=\text{softmax}_S(\text{O[B, T, S, K, G]})\] \[\text{U[B, T, S, K, G]} \cdot \text{V[B, S, K, H]} \rightarrow \text{X[B, T, K, G, H]}\] 因此我们的总字节数是 $2 \times \text{sizeof}(Q) + 2 \times \text{sizeof(K 或 V)} = 4BTNH + 4BSKH = 4BHK \times (TG + S)$,总 FLOPs 是 $4BTSNH + O(BTSN)$,算术强度是 $4BTSKGH / (4BHK \times (TG + S))$。 基本上,在预填充期间我们有 $S=T$,所以算术强度为 $4BT^2KGH / 4BHKT \times (G+1) = TG/(G + 1) = O(T)$。在生成期间 $T=1$,所以我们有 $4BSKGH / (4BHK \times (G + S)) = SG / (G + S) \rightarrow G$(假设 $S$ 非常大)。根据你对问题的理解,在预填充或训练期间,假设没有序列分片,自注意力在 S=240 时受计算限制。在生成期间,我们永远不会受计算限制,因为 $G$ 很小。尽管如此,你可以看到增加 $G$ 会使我们更接近计算限制。

问题 5: 在什么序列长度下,自注意力 FLOPs 等于 QKVO 投影 FLOPs?

答案: 这纯粹是一个何时 $24BTDNH = 12BT^2NH$ 的问题。简化后我们得到 $2D = T$,例如对于 $D=4096$,这是 $8192$。这告诉我们,对于大多数合理的上下文长度,矩阵乘法 FLOPs 更大。

问题 6: 假设我们在前向传递期间仅保存 Transformer 层中 7 个主要矩阵乘法的输出(Q, K, V, O + 三个 FFW 矩阵)。在反向传递期间,我们需要“重算”多少额外的 FLOPs?

答案: 仅保存七个矩阵乘法输出(Q, K, V, O, W₁, W₂, W₃)意味着反向传递必须重新计算两个注意力矩阵乘法: \[QK^{\top} \quad\text{和}\quad \operatorname{softmax}(QK^{\top})V.\] 以便获得 $\frac{\partial L}{\partial W_O}$。 每个都是在 $B$ 个序列和 $N$ 个头上批处理的 $T \times T$ 矩阵乘法,因此额外的 FLOPs 为: \[4 \; B \, T^{2} \, N \, H.\] 其他重新计算的操作包括: 1. 用于 $\frac{\partial L}{\partial W_{In1}}$ 和 $\frac{\partial L}{\partial W_{In2}}$ 的 $O(BTD)$。 2. 用于 $\frac{\partial L}{\partial W_{Out}}$ 的 $O(BTF)$。

问题 7: DeepSeek v3 表示它在 14.8T Token 上训练了 2.79M H800 小时。鉴于它有 37B 激活参数,他们大约实现了多少硬件利用率?提示:注意他们使用了没有结构化稀疏性的 FP8 FLOPs。

答案: 从规格表来看,我们发现具有稀疏性的 FP8 性能为 3,026 TFLOPs/s,或者通常在没有稀疏性的情况下为一半 ($1.513 \times 10^{15}$ FLOPs/s)。2.79M H800 小时意味着 $2.79 \times 10^6 \times 1.513 \times 10^{15} \times 60 \times 60 = 1.52 \times 10^{25}$ 总 FLOPs。鉴于激活参数计数为 37B,这次训练运行应该使用了大约 $6 \times 37 \times 10^9 \times 14.8 \times 10^{12} = 3.3 \times 10^{24}$ FLOPs。这意味着 FLOPs 利用率约为 $3.3 \times 10^{24} / 1.52 \times 10^{25} \approx 21.7\%$。

问题 8: 混合专家 (MoE) 模型具有标准稠密 MLP 块的 $E$ 个副本,每个 Token 激活其中的 $k$ 个专家。在 TPU v5e 上,权重为 int8 的 MoE 需要多大的 Token 批处理量才能达到计算限制?对于具有 256 个(路由)专家且 $k=8$ 的 DeepSeek,这个数字是多少?

答案: 因为我们有每个专家的 $E$ 个副本,在 int8 中,对于每个权重矩阵,我们需要加载 $E \times D \times F$ 字节。因为每个 Token 激活 $k$ 个专家,对于每个权重矩阵,我们有 $2 \times k \times B \times D \times F$ FLOPs。为了在 bfloat16 FLOPs 下达到计算限制,我们需要算术强度超过 240,这发生在 $(2 \times k \times BDF) / EDF > 240$ 或 $k \times B / E > 120$ 时。 因此,我们需要 $B > 120 \times E / k$ 才能达到计算限制。对于 DeepSeek,这给出了 $B > 120 \times 256 / 8 = 3840$。在生成时,这是一个非常大的批处理量。


附录

附录 A:FLASH ATTENTION 是如何工作的?

将 Transformer 扩展到极长上下文的传统反对意见是,注意力 FLOPs 和内存使用量随上下文长度呈二次方增长。虽然注意力的 QK 乘积确实具有形状 $[B, T, S, N]$(其中 B 是批大小,S 和 T 是 Q 和 K 的序列维度,N 是头数),但这一说法带有一些严肃的注意事项:

  1. 正如我们之前指出的,即使这是二次方的,注意力 FLOPs 仅在 $S > 8 \times D$ 时才占据主导地位,尤其是在训练期间,与内存中存在的所有权重和激活检查点相比,单个注意力矩阵的内存很小,特别是在分片时。
  2. 我们不需要为了计算注意力而实例化完整的注意力矩阵!我们可以计算局部和与最大值,并避免实例化超过一小块数组。虽然总 FLOPs 仍然是二次方的,但我们大大减轻了内存压力。

这第二个观察结果最初由 Rabe 等人 (2021) 提出,后来在 Flash Attention 论文 (Dao 等人 2022) 中得到应用。基本思想是按 K/V 块计算注意力,计算局部 softmax 和一些辅助统计数据,然后将它们传递给下一个块,后者将其与自己的局部块结合。具体来说,我们计算:

  1. M: $q \cdot k$ 在序列维度上的运行最大值。
  2. O: 运行中的完整注意力 softmax 在序列维度上的结果。
  3. L: 运行中的分母 $\sum_i \exp(q \cdot k_i - \text{运行最大值})$。

有了这些,我们只需恒定量的内存即可计算新的最大值、新的运行和以及新的输出。简要描述其工作原理,注意力大致是这个操作:

\[\text{Attn}(Q, K, V) = \sum_i \frac{\exp(Q \cdot K_i - \max_j Q \cdot K_j) V_i}{\sum_l \exp(Q \cdot K_l - \max_j Q \cdot K_j)}\]

减去最大值是为了数值稳定性,并且可以在不影响结果的情况下添加,因为 $\sum_i \exp(a_i + b) = \exp(b) \sum \exp(a)$。仅看上面的分母,如果我们想象有两个连续的键向量块 $K^1$ 和 $K^2$,并且我们为每个块计算局部 softmax 和 $L^1$ 和 $L^2$:

\[L^1 = \sum_i \exp(Q \cdot K_i^1 - \max_j Q \cdot K_j^1)\] \[L^2 = \sum_i \exp(Q \cdot K_i^2 - \max_j Q \cdot K_j^2)\]

然后我们可以通过以下方式将这些结合成这两个块的总 softmax 和:

\[L^\text{combined} = \exp(M^1 - \max(M^1, M^2)) \cdot L^1 + \exp(M^2 - \max(M^1, M^2)) \cdot L^2\]

其中

\[M^1 = \max_j Q \cdot K_j^1 \text{ 且 } M^2 = \max_j Q \cdot K_j^2\]

这也可以对完整的 softmax 执行,为我们提供了一种累积任意大 softmax 和的方法。

Flash Attention 算法

从硬件角度来看,这让我们可以将 Q 块放入 VMEM(上述算法称为片上 SRAM),因此我们只需在每次迭代中加载 KV 块,从而提高了算术强度。我们还可以将运行统计数据保留在 VMEM 中。

最后值得强调的一个微妙点是注意力 softmax 的一个属性,它被用来使 Flash VJP(反向模式导数)计算在训练中变得实用。如果我们定义一个中间 softmax 数组为:

\[S_{ij} = \frac{e^{\tau q_i \cdot k_j}}{\sum_l e^{\tau q_i \cdot k_l}}\]

在注意力中,我们从反向模式 $dO$ 和 $V$ 数组中获得 $dS$:

\[dS_{ij} = dO_{id} \cdot_d V_{jd} = \sum_d dO_{id} V_{jd}\]

在将此梯度反向传播到 Q 和 K 期间:

\[d(q_i \cdot k_j) = (dS_{ij} - S_{ij} \cdot_j dS_{ij}) S_{ij}\]

我们利用一个恒等式,允许我们将沿长键长度维度的收缩替换为沿特征深度维度的局部收缩。

\[\begin{align*} S_{ij} \cdot_j dS_{ij} &= \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} \sum_d dO_{id} V_{jd} \\ &= \sum_d dO_{id} \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} V_{jd} \\ &= \sum_d dO_{id} O_{id} \\ &= dO_{id} \cdot_d O_{id} \end{align*}\]

这种替换对于能够实现 VJP 的序列块局部计算至关重要,并支持进一步的巧妙分片方案,如环形注意力 (ring attention)。

第 5 章

如何并行化 Transformer 训练

我们所说的扩展(Scaling)是什么意思?

“模型扩展(model scaling)”的目标是,在增加用于训练或推理的芯片数量时,能够实现吞吐量成比例的线性增长(我们称之为强扩展 (strong scaling))。单块芯片的性能取决于内存带宽和浮点运算(FLOPs)之间的权衡,而在集群层面,性能则取决于能否通过与有用的 FLOPs 重叠来隐藏芯片间的通信成本。这并非易事,因为增加芯片数量会增加通信负载,同时又会减少每个设备上可用于隐藏该通信的计算量。正如我们在第 3 部分中看到的,切片矩阵乘法通常需要昂贵的 AllGather 或 ReduceScatter 操作,这可能会阻塞 TPU,使其无法进行有用的工作。本节的目标是找出这些操作在什么时候会变得过于昂贵

在本节中,我们将讨论四种常见的并行化方案:(纯)数据并行 (data parallelism)完全切片数据并行 (fully-sharded data parallelism, FSDP / ZeRO sharding)张量并行 (tensor parallelism)(也称为模型并行)以及(简要讨论)流水线并行 (pipeline parallelism)。对于每种方案,我们将展示会产生怎样的通信成本,以及该成本在何时开始成为计算的瓶颈。我们将重点关注通信瓶颈——虽然内存容量的限制也很重要,但在预训练期间使用重计算(激活检查点,activation checkpointing)以及极大量的芯片时,内存通常不会成为我们的瓶颈。这里我们也不讨论针对 MoE(混合专家模型)的专家并行(expert parallelism),因为那会极大地扩展设计空间,我们只讨论密集型 Transformer 的基础情况。 在本节中,你可以仅关注芯片间的通信成本,因为只要单芯片的批次大小(batch size)足够大,数据从 HBM 传输到 MXU 的过程就已经与计算重叠了。

我们将使用以下符号来简化贯穿本节的计算。

符号 含义 (模型参数)
D dmodel (隐藏层维度/残差流维度)
F dff (前馈网络维度)
B 批次维度 (批次中的 token 数量;指总数,而非单设备数量)
T 序列长度
L 模型中的层数
符号 含义 (硬件特征)
C 每块芯片的 FLOPS/s
W 网络带宽 (双向,通常带下标,如 $W_{\text{ici}}$ 或 $W_{\text{dcn}}$)
X 沿 mesh X 轴的芯片数量
Y 沿备用 mesh Y 轴的芯片数量
Z 沿第三个 mesh Z 轴的芯片数量

为了简单起见,我们将 Transformer 近似看作是 MLP 块的堆叠——正如我们在第 4 部分中看到的,对于较大的模型,注意力机制所占的 FLOPs 比例相对较小。我们也会忽略门控矩阵乘法(gating matmul),从而为每层留下以下简单的结构:

<b>图:</b> 一个简化的 Transformer 层。我们将每个 FFW 块视为两个矩阵的堆叠:<b>W<sub>in</sub></b>: <code>bf16[D, F]</code> (升维投影) 和 <b>W<sub>out</sub></b>: <code>bf16[F, D]</code> (降维投影),输入为 <b>In</b>: <code>bf16[B, D]</code>。
图: 一个简化的 Transformer 层。我们将每个 FFW 块视为两个矩阵的堆叠:Win: bf16[D, F] (升维投影) 和 Wout: bf16[F, D] (降维投影),输入为 In: bf16[B, D]
以下是没有使用任何并行的简化版 Transformer 的完整算法。
**前向传播 (Forward pass):** 需要计算 Loss[B] 1. Tmp[B, F] = In[B, D] *D Win[D, F] 2. Out[B, D] = Tmp[B, F] *F Wout[F, D] 3. Loss[B] = ... **反向传播 (Backward pass):** 需要计算 dWout[F, D], dWin[D, F] 1. dOut[B, D] = ... 2. dWout[F, D] = Tmp[B, F] *B dOut[B, D] 3. dTmp[B, F] = dOut[B, D] *D Wout[F, D] 4. dWin[D, F] = In[B, D] *B dTmp[B, F] 5. dIn[B, D] = dTmp[B, F] \*F Win[D, F] (*前几层需要用到*)
我们提供这个算法是为了与加入了通信的算法进行比较。

接下来是我们将要讨论的 4 种并行化方案。每种方案都可以被认为是由上图中对 InWinWoutOut 的不同切分(sharding)方式来唯一确定的。

1. 数据并行 (Data parallelism): 激活值沿批次维度切片,参数和优化器状态在每个设备上都有完整副本。通信仅发生在反向传播期间。

$$\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]$$

2. 完全切片数据并行 (Fully-sharded data parallelism, FSDP / ZeRO-3): 激活值沿批次维度切片(类似于纯数据并行),参数沿相同的 mesh 轴切片,并在前向传播中使用之前“即时(just-in-time)”进行 AllGather。优化器状态也沿批次切片。这种方法减少了重复内存。

$$\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]$$

3. 张量并行 (Tensor parallelism) (也称为 Megatron 切片或模型并行): 激活值沿 D 维度 ($d_\text{model}$) 切片,参数沿 F 维度 ($d_{ff}$) 切片。在每个 block 之前和之后分别对激活值进行 AllGather 和 ReduceScatter。兼容 FSDP。

$$\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]$$

4. 流水线并行 (Pipeline parallelism): 权重沿层维度(layer dimension)切片,激活值被分为微批次(microbatch)并沿层维度滚动传递。流水线阶段之间的通信极小(仅需要进行单跳的激活值移动)。为了方便(滥用)符号表示,我们写作:

$$\text{In}[L_Z, B, D][i] \cdot_D W_\text{in}[L_Z, D, F][i] \cdot_F W_\text{out}[L_Z, F, D][i] \rightarrow \text{Out}[L_Z, B, D][i]$$

数据并行

语法: $$\text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]$$

如果你的模型在即使是非常小的批次大小(>240 tokens,以确保计算密集)下也能放入单块芯片,你总是应该使用简单的数据并行。 纯数据并行将我们的激活值分配到任意数量的 TPU 上,只要 TPU 的数量小于我们的批次大小即可。前向传播不涉及任何通信,但在每个 step 结束时,每个 TPU 都会对其局部梯度执行 AllReduce 操作,以便在更新参数之前进行同步。

<b>图:</b> 纯数据并行(前向传播)示意图。我们的激活值(左侧)沿批次维度完全切片,而我们的权重则是完全复制的,因此每个 TPU 都有一个相同的权重副本。这意味着权重的总内存增加了 N 倍,但在前向传播中不需要通信。
图: 纯数据并行(前向传播)示意图。我们的激活值(左侧)沿批次维度完全切片,而我们的权重则是完全复制的,因此每个 TPU 都有一个相同的权重副本。这意味着权重的总内存增加了 N 倍,但在前向传播中不需要通信。
以下是前向和反向传播的完整算法。纯粹为了简洁,我们滥用符号将 dL/dOut 写为 dOut。
**纯数据并行算法 (Pure Data Parallelism Algorithm):** **前向传播:** 需要计算 Loss[BX] 1. Tmp[BX, F] = In[BX, D] \*D Win[D, F] 2. Out[BX, D] = Tmp[BX, F] \*F Wout[F, D] 3. Loss[BX] = ... **反向传播:** 需要计算 dWout[F, D], dWin[D, F] 1. dOut[BX, D] = ... 2. dWout[F, D] {UX} = Tmp[BX, F] \*B dOut[BX, D] 3. dWout[F, D] = **AllReduce**(dWout[F, D] {UX}) (*不在关键路径上,可以异步完成*) 4. dTmp[BX, F] = dOut[BX, D] \*D Wout[F, D] 5. dWin[D, F] {UX} = In[BX, D] \*B dTmp[BX, F] 6. dWin[D, F] = **AllReduce**(dWin[D, F] {UX}) (*不在关键路径上,可以异步完成*) 7. dIn[BX, D] = dTmp[BX, F] \*F Win[D, F] (*前几层需要用到*)
我们忽略了损失函数的具体细节,并缩写 $\text{Tmp} = W_\text{in} \cdot \text{In}$。请注意,尽管我们最终的损失是平均值 **AllReduce**(Loss[BX]),但我们只需要在反向传播计算权重梯度平均值时才需要执行 AllReduce。

请注意,前向传播没有通信——通信全都在反向传播中!反向传播还有一个很好的特性,那就是 AllReduce 操作不在“关键路径(critical path)”上,这意味着每个 AllReduce 可以在任何方便的时候执行,而不会阻塞你执行后续操作。如果整体通信成本超过了我们的总计算成本,它仍然可能成为我们的瓶颈,但从实现的角度来看,它要宽容得多。我们将会看到模型/张量并行并不具备这个属性。

为什么要这么做? 纯数据并行通过在批次维度上分配激活值来减轻激活内存压力,这允许我们在有更多芯片来切分批次维度时,几乎可以任意增加批次大小。特别是在训练期间,激活值通常主导我们的内存使用,这非常有用。

为什么不这么做? 纯数据并行对于减轻模型参数或优化器状态带来的内存压力没有任何帮助,这意味着在参数 + 优化器状态无法放入单个 TPU 的大规模有趣模型中,纯数据并行很少有用。为了对规模有个概念,假设我们使用 bf16 存储参数,使用 fp32 的 Adam 存储优化器状态进行训练Adam 存储参数、一阶和二阶累加器。因为参数在 bfloat16 中,优化器状态在 float32 中,所以这给出了每个参数 2 + 8 = 10 字节。,我们能放入的最大模型拥有 $$\text{TPU 内存} / 10$$ 的参数量,所以例如在具有 96GB HBM 和使用纯数据并行的 TPUv5p 芯片上,这大约是 90 亿(9B)个参数。

**关键结论:**我们能用 Adam 和纯数据并行训练的最大模型的 $$\text{参数量} = \text{每台设备的 HBM} / 10$$。对于 TPU v5p,这大约是 90 亿参数。请注意,这不包括梯度检查点,所以这实际上并不实用。这是批次为 1 个 token 时的绝对下限。

为了在训练实际模型时使其变得有用,我们至少需要部分地切片模型参数或优化器。

我们什么时候会受限于通信瓶颈? 正如我们在上面看到的,我们每层有两个 AllReduce,每个的大小为 $$2DF$$(针对 bf16 权重)。数据并行在什么时候会让我们受限于通信(communication bound)?

如上表所示,设 $C$ = 单芯片 FLOPs,$W_{\text{ici}}$ = 双向网络带宽,而 $X$ = 批次被划分成的切片数量我们假设这种划分是在 ICI mesh 上完成的,所以相关的网络带宽是 $W_\text{ici}$。让我们计算执行相关矩阵乘法所需的时间 $$T_\text{math}$$,以及所需的通信时间 $$T_\text{comms}$$。由于这种并行方案在前向传播中不需要通信,我们只需计算反向传播时的这些数量。

通信时间: 从前一节我们知道,在 1D mesh 中执行 AllReduce 所需的时间仅取决于被 AllReduce 的数组的总字节数以及 ICI 带宽 $W_\text{ici}$;具体来说,AllReduce 时间是 $2 \cdot \text{总字节数} / W_\text{ici}$。因为我们需要对 $W_\text{in}$ 和 $W_\text{out}$ 都进行 AllReduce,所以每层有 2 个 AllReduce。每个 AllReduce 都是针对一个权重矩阵,即包含 $DF$ 个参数的数组,也就是 $2DF$ 字节。把这些放在一起,单层中进行 AllReduce 的总时间是:

$$\begin{align} T_\text{comms} &= \frac{2 \cdot 2 \cdot 2 \cdot D \cdot F}{W_\text{ici}}. \\ \end{align}$$

矩阵乘法计算时间: 每层在前向传播中包含两个矩阵乘法,或者说在反向传播中包含四个矩阵乘法,每个都需要 $2(B/X)DF$ 次 FLOPs。因此,对于反向传播中的单层,我们有:

$$\begin{align} T_\text{math} &= \frac{2 \cdot 2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ \end{align}$$

因为我们交叠执行通信和计算,每层的总时间是这两个数量中的最大值:

$$\begin{aligned} T &\approx \max(\frac{8 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{8 \cdot D \cdot F}{W_\text{ici}}) \\ T &\approx 8 \cdot D \cdot F \cdot \max(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}) \end{aligned}$$

当 $$T_\text{math}/T_\text{comms} > 1$$ 时,我们处于计算受限(compute-bound)状态,也就是说:

$$\begin{align} \frac{B}{X} > \frac{C}{W_\text{ici}}. \end{align}$$

结果是,为了在使用数据并行时保持计算受限,我们需要单设备的批次大小 $$B / X$$ 超过 ICI 操作强度 $C / W_\text{ici}$。这归根结底是因为计算时间与单设备批次大小成正比,而通信时间则独立于该数量(因为我们正在传输模型权重)。注意 $B/X > C/W_\text{ici}$ 这一条件与单设备计算受限法则 $B > 240$ 非常相似;在那种情况下,法则同样源于计算时间随批次大小变化,而数据传输大小(在 $B \ll F, D$ 的机制下)与批次大小无关。

让我们代入一些实际的数字来获得规模感。对于 TPUv5p,在 ICI 上进行 1D 数据并行时,C=4.6e14W=2 * 9e10,所以为了避免受限于通信,我们的单芯片批次大小至少要达到 2,550。既然我们可以在多个轴上进行数据并行,如果我们将 TPUv5p pod 的所有三个轴都专用于纯数据并行,我们的带宽 $W_\text{ici}$ 就会乘以 3,而且我们可以降低到每个 TPU 仅 850 的 BS,或者每个 pod(8960个芯片)总计每批次 760 万个 tokens!这告诉我们,我们很难被纯数据并行所阻碍!

**注 [上下文并行 (context parallelism)]:** 在整个本节中,$B$ 始终指代**以 token 计**的总批次大小。显然,我们的批次是由许多不同的序列组成的,那么这又是如何运作的呢?就 MLP 而言,**token 就是 token**!它属于同一序列还是两个不同的序列都无关紧要。因此,我们或多或少可以自由地在批次和序列维度上进行数据并行:我们称之为上下文并行或序列并行,但你可以简单地将其视为另一种数据并行。注意力机制比 MLP 要复杂一些,因为我们要进行一些跨序列的计算,但这可以通过在注意力计算期间收集 KVs 或 Qs 并小心地交叠计算与通信来处理(通常使用被称为“环形注意力 (ring attention)”的技术)。在整个本节中,我们将忽略序列维度,并假设存在某种数量的批次或序列并行。

关于多个 mesh 轴的注记: 我们应该快速注意一下多轴如何影响可用带宽。当我们为给定的并行策略使用多个 mesh 轴时,我们会获得更多的带宽。

完全切片数据并行 (Fully-Sharded Data Parallelism, FSDP)

语法: $$\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]$$

完全切片数据并行(通常称为 FSDP 或 ZeRO 切片)将模型的优化器状态和权重分布在数据并行的各个切片(shard)上,并在需要时高效地对它们进行收集(gather)和分发(scatter)。与纯数据并行相比,FSDP 极大地降低了单设备的内存使用量,节省了反向传播的 FLOPs,且带来的额外开销微乎其微。

<b>图:</b> FSDP 沿着数据维度切分 W<sub>in</sub> 的收缩维度(contracting dimension)和 W<sub>out</sub> 的输出维度。这减少了内存,但(根据第 3 部分)要求我们在执行矩阵乘法之前先收集(gather)W 的权重。请注意,激活值(左侧)<i>并没有沿着收缩维度进行切片</i>,这正是迫使我们需要进行收集操作的原因。<b>请注意,我们的权重优化器状态同样也是沿着收缩维度切片的。</b>
图: FSDP 沿着数据维度切分 Win 的收缩维度(contracting dimension)和 Wout 的输出维度。这减少了内存,但(根据第 3 部分)要求我们在执行矩阵乘法之前先收集(gather)W 的权重。请注意,激活值(左侧)并没有沿着收缩维度进行切片,这正是迫使我们需要进行收集操作的原因。请注意,我们的权重优化器状态同样也是沿着收缩维度切片的。

你还记得(在第 3 部分中)AllReduce 可以被分解为一个 AllGather 和一个 ReduceScatter。这意味着,对于标准数据并行中进行的完整梯度 AllReduce,我们可以不这么做,而是将权重和优化器状态跨芯片切片(shard),在前向传播的每一层按需 AllGather 它们,而在反向传播期间使用 ReduceScatter 来分散梯度,无需支付额外的成本。

这里是 FSDP 的完整算法。
**完全切片数据并行 (FSDP):** **前向传播:** 需要计算 Loss[BX] 1. Win[D, F] = **AllGather**(Win[DX, F]) (*不在关键路径上,可以在计算上一层时提前完成*) 2. Tmp[BX, F] = In[BX, D] \*D Win[D, F] (*现在可以丢弃 Win[D, F] 了*) 3. Wout[F, D] = **AllGather**(Wout[F, DX]) (*不在关键路径上,可以在计算上一层时提前完成*) 4. Out[BX, D] = Tmp[BX, F] \*F Wout[F, D] 5. Loss[BX] = ... **反向传播:** 需要计算 dWout[F, DX], dWin[DX, F] 1. dOut[BX, D] = ... 2. dWout[F, D] {UX} = Tmp[BX, F] \*B dOut[BX, D] 3. dWout[F, DX] = **ReduceScatter**(dWout[F, D] {UX}) (*不在关键路径上,可以异步完成*) 4. Wout[F, D] = **AllGather**(Wout[F, DX]) (*可以提前完成*) 5. dTmp[BX, F] = dOut[BX, D] \*D Wout[F, D] *(可以在此丢弃 Wout[F, D])* 6. dWin[D,F] {UX} = dTmp[BX, F] \*B In[BX, D] 7. dWin[DX, F] = **ReduceScatter**(dWin[D, F] {UX}) *(不在关键路径上,可以异步完成)* 8. Win[D, F] = **AllGather**(Win[DX, F]) (*可以提前完成*) 9. dIn[BX, D] = dTmp[BX, F] \*F Win[D, F] (*前几层需要用到) (可以在此丢弃 Win[D, F]*)

这也被称为 "ZeRO 切片"(ZeRO Sharding),源自 "零冗余优化器"(Zero Redundancy Optimizer),因为我们不执行任何不必要的计算,也不存储任何不必要的状态。ZeRO-{1,2,3} 分别指代以上述方式切分优化器状态、梯度和权重的技术。由于它们都有相同的通信成本从技术上讲,FSDP 在前向传播中增加了一些纯数据并行没有的通信,但它的比例与反向传播相同,因此不应影响通信的“屋顶线(roofline)”。这里的关键是,ZeRO-3 将反向传播的 AllReduce 转换为一个 AllGather 和一个 ReduceScatter,两者具有相同的总通信量。,我们基本上可以总是使用 ZeRO-3 切片,即在设备组上切分参数、梯度和优化器状态。

为什么要这么做? 标准数据并行涉及大量重复工作。每个 TPU 都会对全部梯度进行 AllReduce,然后更新全部优化器状态(所有 TPU 上的工作完全相同),最后更新参数(也是完全重复的)。对于 ZeRO 切片(切分梯度/优化器状态),代替 AllReduce,你可以对梯度执行 ReduceScatter 操作,只更新你负责的那部分优化器状态,并更新一部分参数,然后在前向传播需要时再 AllGather 参数。

什么时候我们会遇到通信瓶颈? 我们相对的 FLOPs 和通信成本与纯数据并行完全相同,因为反向传播中的每个 AllReduce 变成了 AllGather + ReduceScatter。请记住,AllReduce 是由开销各占一半的 AllGather 和 ReduceScatter 实现的。在这里我们对前向传播进行建模,因为它具有与反向传播相同的 FLOPs-通信比:

$$\begin{aligned} T_\text{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \\ T_\text{comms} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \\ T &\approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{4 \cdot D \cdot F}{W_\text{ici}}\right) \\ T &\approx 4 \cdot D \cdot F \cdot \max\left(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}\right) \end{aligned}$$

因此,如同纯数据并行一样,当 $$B / X > C / W_\text{ici}$$ 时,即单设备批次大小 $B/X$ 超过“ICI 运算强度” $C/W_\text{ici}$(在 TPU v5p 上为 4.59e14 / 1.8e11 = 2550)时,我们受到计算瓶颈限制。这对我们非常有利,因为这意味着如果我们的单设备批次大小足够大以致在纯数据并行下受计算限制,我们就可以——不用担心脱离计算限制区间——直接升级为 FSDP,从而节省了海量的参数和优化器状态内存!虽然我们在前向传播中增加了通信,但这项成本微不足道,因为它直接与前向传播的 FLOPs 重叠隐藏了。

**关键结论:**在 TPU v5 上,当每台设备的批次大小小于 $2550 / M_X$ 时(其中 $M_X$ 是 mesh 轴的数量),FSDP 和纯数据并行都会受到带宽限制。

举例来说,DeepSeek-V2(最近少有的公布其训练批次大小的强力模型之一)使用的批次大小约为 4000 万(40M)tokens。这将使我们能够将规模扩展到大约 47,000 块芯片,即约 5 个 TPUv5 pod,然后才会触及带宽上限。

对于 LLaMA-3 70B 模型,其训练大约消耗了 6.3e24 (15e12 * 70e9 * 6) FLOPs 计算量,我们可以将 16M token 的批次分布在约 16e6 / (2550 / 3) = 18,823 块芯片上(约 2 个由 8960 块芯片组成的 pod),假设每块芯片以 4.59e14 FLOPs 且拥有 50% 峰值利用率(通常称为模型 FLOPs 利用率,MFU)运行,我们可以在大约 17 天内完成训练。相当不错!但让我们继续探讨还能如何做得更好。

**关于临界批次大小的注记 (Note on critical batch size)**: 有些违反直觉的是,随着总批次大小减小(在芯片数量固定的情况下),我们反而更容易遭遇通信瓶颈。数据并行和 FSDP 能够让我们扩展到任意数量的芯片,前提是我们能不断增大批次大小!然而在实践中,随着批次大小增大,我们往往会看到训练边际收益递减,因为梯度变得几乎没有噪音了。我们有时还会观察到训练不稳定。因此,在“无限算力机制”下寻找最佳切片方案的博弈,通常是从一个由 Scaling Law 决定的固定批次大小以及一个已知(庞大)数量的芯片开始,目标是找到一种能将这较小批次分布在如此多芯片上的切分策略。

张量并行 (Tensor Parallelism)

语法: $$\text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]$$ (我们用 $$Y$$ 来表示之后要与 FSDP 组合)

在完全切片的数据并行的 AllReduce 中,我们在芯片间移动权重。我们也可以对模型的前馈维度进行切片并在层执行期间移动激活值——这被称为“1D 模型并行”或 Megatron 切片。这能为每个 pod 解锁更小的高效批次大小。下图展示了以这种方式切片的单一矩阵乘法示例:

<b>图:</b> 一个基础张量并行的示例。因为我们仅在 Y 维度切片激活值(不像 FSDP 那样在 X 上切片),我们复制了 X 维度的激活值。使用我们的标准语法,这写作 <b>A</b>[B, D<sub>Y</sub>] * <b>B</b>[D, F<sub>Y</sub>] -> <b>C</b>[B, F<sub>Y</sub>]。因为我们仅沿其中一个收缩维度切片,我们通常在矩阵乘法前对激活值 <b>A</b> 执行 AllGather。
图: 一个基础张量并行的示例。因为我们仅在 Y 维度切片激活值(不像 FSDP 那样在 X 上切片),我们复制了 X 维度的激活值。使用我们的标准语法,这写作 A[B, DY] * B[D, FY] -> C[B, FY]。因为我们仅沿其中一个收缩维度切片,我们通常在矩阵乘法前对激活值 A 执行 AllGather。

如前所述,In\[B, DY\] *D Win\[D, FY\] *F Wout\[FY, D\] -> Out\[B, DY\] 意味着我们必须在进行第一次矩阵乘法之前收集(gather)激活值。当激活值尺寸小于权重时,这比 ZeRO 切片代价更低。 这通常仅在叠加了某种程度的 ZeRO 切片(这减少了 gather 操作的大小)时才成立。这就是我们倾向于混合使用 ZeRO 切片和张量并行的原因之一。

这里是张量并行的算法!
**张量并行 (Tensor Parallelism):** **前向传播:** 需要计算 Loss[B] 1. In[B, D] = **AllGather**(In[B, DY]) *(在关键路径上)* 2. Tmp[B, FY] = In[B, D] \*D Win[D, FY] *(收缩维度未切片,所以没有通信)* 3. Out[B, D] {UY} = Tmp[B, FY] \*F Wout[FY, D] 4. Out[B, DY] = **ReduceScatter**(Out[B, D] {UY}) *(在关键路径上)* 5. Loss[B] = ... **反向传播:** 需要计算 dWout[FY, D], dWin[D, FY] 1. dOut[B, DY] = ... 2. dOut[B, D] = **AllGather**(dOut[B, DY]) *(在关键路径上)* 3. dWout[FY, D] = Tmp[B, FY] \*B dOut[B, D] 4. dTmp[B, FY] = dOut[B, D] \*D Wout[FY, D] *(可以在此丢弃 dOut[B, D])* 5. In[B, D] = **AllGather**(In[B, DY]) *(由于可以与前向传播的 (1) 共享,此步骤可以跳过)* 6. dWin[D, FY] = dTmp[B, FY] \*B In[B, D] 7. dIn[B, D] {UY} = dTmp[B, FY] \*F Win[D, FY] *(前几层需要用到)* 8. dIn[B, DY] = **ReduceScatter**(dIn[B, D] {UY}) *(在关键路径上)*

张量并行的一个妙处在于它能与我们 Transformer 前向传播中的两个矩阵乘法很好地交互。最朴素的做法是,我们在两次矩阵乘法后分别进行 AllReduce。但在这里,我们首先执行 In[B, DY] * Win[D, FY] -> Tmp[B, FY] 然后再执行 Tmp[B, FY] * Wout[FY, D] -> Out[B, DY]。这意味着我们在开始时对 In 进行 AllGather,在最后对 Out 进行 ReduceScatter,而不是执行 AllReduce。

这代价有多大? 让我们只对前向传播进行建模——反向传播只是这里的每一步操作的转置。在 1D 张量并行中,我们在第一次矩阵乘法前 AllGather 激活值,在第二次后 ReduceScatter 它们,每次发送两个字节 (bf16)。让我们算出什么时候会遭遇通信瓶颈。

$$\begin{align} T_\text{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \\ T_\text{comms} & = \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\\ \textnormal{T} & \approx \max \left(\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}, \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\right) \end{align}$$

注意到我们想要计算成本大于通信成本,我们得出:

$$\begin{align} \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} > \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}} \end{align}$$ $$\begin{align} \frac{F}{Y \cdot C} > \frac{1}{W_\text{ici}} \end{align}$$ $$\begin{align} F > Y \cdot \frac{C}{W_\text{ici}} \end{align}$$

因此例如,对于 TPUv5p,在 bf16 精度下 $C / W_{ici} = 2550$,所以我们只能在 $Y < F / 2550$ 时使用张量并行。当我们有多个 ICI 轴时,我们的 $T_\text{comms}$ 会减少 $M_Y$ 倍,所以我们得到 $Y < M_Y \cdot F / 2550$。

**关键结论:**当 $Y > M_Y \cdot F / 2550$ 时,张量并行就会受到通信限制。对于大多数模型来说,这是在 8 路到 16 路张量并行之间。

注意,这不取决于计算精度,因为例如对于 int8 操作,在 TPUv5p 上,$$C_\text{int8} / W_{ici}$$ 是 $$5100$$ 而不是 $$2550$$,但通信量同时也减半了,所以这两个系数互相抵消了。

让我们看一些例子:

结合 FSDP 与张量并行

语法: $$\text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y]$$

FSDP 和张量并行(Tensor Parallelism, TP)的好处在于它们可以结合使用。通过在两个轴上切片 WinWout,我们同时节省了内存和计算量。由于我们在 X 轴上对 B 进行切片,我们减小了模型并行的 AllGather 的大小,而因为我们在 Y 轴上对 F 进行切片,我们降低了 FSDP 的通信开销。这意味着两者的结合能够让我们实现比前面看到的更低的有效批次大小。

<b>图:</b> 结合 FSDP 与张量并行的图示。与其他情况不同的是,这里没有任何模型参数的重复。
图: 结合 FSDP 与张量并行的图示。与其他情况不同的是,这里没有任何模型参数的重复。
以下是混合 FSDP + 张量并行的完整算法。尽管有大量的通信,但所有的 AllGather 和 ReduceScatter 都更小,因为我们将激活值按批次切片,并将权重做了更细致的张量切片!
**前向传播:** 需要计算 Loss[B] 1. In[BX, D] = **AllGather**Y(In[BX, DY]) *(在关键路径上)* 2. Win[D, FY] = **AllGather**X(Win[DX, FY]) *(可以提前完成)* 3. Tmp[BX, FY] = In[BX, D] \*D Win[D, FY] 4. Wout[FY, D] = **AllGather**X(Wout[FY, DX]) *(可以提前完成)* 5. Out[BX, D] {UY} = Tmp[BX, FY] \*F Wout[FY, D] 6. Out[BX, DY] = **ReduceScatter**Y(Out[BX, D] {UY}) *(在关键路径上)* 7. Loss[BX] = ... **反向传播:** 需要计算 dWout[FY, DX], dWin[DX, FY] 1. dOut[BX, DY] = ... 2. dOut[BX, D] = **AllGather**Y(dOut[BX, DY]) *(在关键路径上)* 3. dWout[FY, D] {UX} = Tmp[BX, FY] \*B dOut[BX, D] 4. dWout[FY, DX] = **ReduceScatter**X(dWout[FY, D] {UX}) 5. Wout[FY, D] = **AllGather**X(Wout[FY, DX]) *(可以提前完成)* 6. dTmp[BX, FY] = dOut[BX, D] \*D Wout[FY, D] *(可以在此丢弃 dOut[B, D])* 7. In[BX, D] = **AllGather**Y(In[BX, DY]) *(不在关键路径上 + 且可以与上一层的 (2) 共享)* 8. dWin[D, FY] {UX} = dTmp[BX, FY] \*B In[BX, D] 9. dWin[DX, FY] = **ReduceScatter**X(dWin[D, FY] {UX}) 10. Win[D, FY] = **AllGather**X(Win[DX, FY]) *(可以提前完成)* 11. dIn[BX, D] {UY} = dTmp[BX, FY] \*F Win[D, FY] *(前几层需要用到)* 12. dIn[BX, DY] = **ReduceScatter**Y(dIn[BX, D] {UY}) *(在关键路径上)*

FSDP 和 TP 的最佳组合是什么? 一个简单但关键的准则是,FSDP 移动的是权重,而张量并行移动的是激活值。这意味着随着批次大小缩小(尤其是当我们进行更多数据并行时),张量并行变得更便宜,因为每个切片的激活值变小了。

因此,通过结合两者,我们能够将每个副本的最小批次大小进一步下压。我们可以像之前一样计算 FSDP 和 TP 的最佳组合量:

设 $$X$$ 为专用于 FSDP 的芯片数量,$$Y$$ 为专用于张量并行的芯片数量。设 $$N$$ 为我们切片中的芯片总数,其中 $$N=XY$$。设 $$M_X$$ 和 $$M_Y$$ 分别为我们在 FSDP 和 TP 上所用的 mesh 轴数量(它们之和应该大致等于 3)。因为前向传播的通信计算比最高,我们纯粹对其进行建模。将上文算法中的通信成本相加,我们得到:

$$T_\text{FSDP comms}(B, X, Y) = \frac{2\cdot 2\cdot D \cdot F}{Y \cdot W_\text{ici} \cdot M_X}$$ $$T_\text{TP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y}$$

同样,我们的总 FLOPs 时间是:

$$T_\text{math} = \frac{2\cdot 2 \cdot B \cdot D \cdot F}{N \cdot C}.$$

为了简化分析,我们做两个假设:第一,我们允许 $X$ 和 $Y$ 取非整数值(只要它们为正且满足 $XY=N$);第二,我们假设 $X$ 和 $Y$ 轴上的通信能够相互完全重叠隐藏。在第二个假设下,总通信时间为:

$$T_\text{comms} = \max\left(T_\text{FSDP comms}, T_\text{TP comms}\right)$$

在探究我们会在何种条件下受限于计算(compute-bound)之前,让我们先寻找使总通信量最小的 $X$ 和 $Y$ 最佳值。由于我们的 FLOPs 与 $X$ 和 $Y$ 无关,最佳设置也就是那些使通信最小化的设置。为此,让我们根据 $X$ 和 $N$($N$ 是固定的系统芯片总数),而不是 $X$ 和 $Y$,来改写上面的 $T_\text{comms}$:

$$T_\text{comms} (X) = \frac{4D}{W_\text{ici}} \max\left(\frac{F \cdot X}{N \cdot M_X}, \frac{B}{X \cdot M_Y}\right)$$

因为 $T_\text{FSDP comms}$ 随 $X$ 单调递增,而 $T_\text{TP comms}$ 随 $X$ 单调递减,最大值一定在 $T_\text{FSDP comms} = T_\text{TP comms}$ 时最小,也就是:

$$\begin{align*} \frac{FX_{opt}}{M_X} = \frac{BN}{X_{opt} M_Y} \rightarrow \\ X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} \end{align*}$$

这极其有用!这告诉我们在给定 $B$、$F$ 和 $N$ 的情况下,最佳的 FSDP 量是多少。让我们带入一些数值。带入真实的参数,即 $N = 64$(对应一个 4x4x4 的芯片阵列)、$B=48,000$、$F=32768$,可以得到 $X\approx 13.9$。所以我们会选择让 $X$ 为 16 并且 $Y$ 为 4,非常接近我们算出的最优点。

**关键结论:**通常在训练期间,最佳的 FSDP 量为 $$X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N}$$。

现在让我们回到我们在所有并行策略中一直问的问题:在何种条件下我们会处于计算瓶颈(compute-bound)? 因为我们可以交叠隐藏计算和通信,只要满足以下条件我们就是计算受限的:

$$\max\left(T_\text{FSDP comms}, T_\text{TP comms}\right) < T_\text{math}$$

定义 ICI 算术强度 $\alpha \equiv C / W_\text{ici}$,我们可以简化上式为:

$$\max\left(\frac{F}{Y \cdot M_X}, \frac{B}{X \cdot M_Y}\right) < \frac{B \cdot F}{N \cdot \alpha}$$

由于我们计算出的 $X_{opt}$ 使得不等式左边的两项相等,我们可以将其直接代入任意一项(注意到 $Y_{opt} = N/X_{opt}$),即:

$$\frac{F}{N \cdot W_\text{ici} \cdot M_X} \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} < \frac{B \cdot F}{N \cdot C}$$

进一步简化,我们发现:

$$ \sqrt{\frac{B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha},$$

其中不等式左边与通信时间成正比,右边与计算时间成正比。请注意,虽然计算时间与批次大小成线性关系(这与并行方式无关),但通信时间与批次大小的平方根成正比。因此,计算时间与通信时间的比值同样与批次大小的平方根成正比:

$$ \frac{T_\text{math}}{T_\text{comms}} = \frac{\sqrt{BF}\sqrt{M_X M_Y}}{\alpha \sqrt{N}}. $$

为了确保这个比值大于一从而让我们处于计算受限状态,我们需要:

$$ \frac{B}{N} > \frac{\alpha^2}{M_X M_Y F}$$

为了得到近似数字,再次代入 $F=32,768$、$\alpha=2550$ 以及 $M_X M_Y=2$(这在 3D mesh 结构下是必然的)。我们大致得到 $B/N > 99$。与纯数据并行(或纯 FSDP)相比,这大约为我们赢得了 8 倍的优化空间;我们之前计算出在纯数据并行且假设 3D mesh 的情况下,$B/N$ 必须超过约 $850$ 才能保持计算受限。

**关键结论:**结合张量并行与 FSDP 允许我们将 $B/N$ 降低至 $$2550^2 / 2F$$。这让我们每块芯片只需处理少至约 100 个的批次,比我们仅使用 FSDP 能实现的水平小了大约 8 倍。

下面,我们绘制了混合 FSDP + TP 的 FLOPs 与通信时间比率图,并在一个典型的 4x4x4 芯片阵列上,将其与纯张量并行 (TP) 以及纯数据并行 (FSDP) 进行了对比。虽然纯 FSDP 并行在超大批次时占据主导优势,但在“批次大小与芯片数量的比率($B/N$)”介于大约 100 到 850 的区间内时,若想达到计算受限(compute-bound)状态,就必须采用混合的 FSDP + TP 策略。

<b>图:</b> 在带有 F=30k 的 TPUv5p 4x4x4 切片上,不同并行的 FLOPs 与通信时间之比。不出所料,张量并行随批次大小有着固定的比率;理想混合的 FSDP + TP 与 $\sqrt{B}$ 成比例扩展,而 FSDP 与 $B$ 成比例扩展。不过在中等批次大小时,只有 FSDP + TP 的比值才能超过 1(计算受限)。
图: 在带有 F=30k 的 TPUv5p 4x4x4 切片上,不同并行的 FLOPs 与通信时间之比。不出所料,张量并行随批次大小有着固定的比率;理想混合的 FSDP + TP 与 $\sqrt{B}$ 成比例扩展,而 FSDP 与 $B$ 成比例扩展。不过在中等批次大小时,只有 FSDP + TP 的比值才能超过 1(计算受限)。

这是另一个展示 TPU v5p 16x16x16 上不同切片方案下 FLOPs 和通信时间作为批次大小函数的例子。

<b>图:</b> 采用不同并行方案时的通信耗时。黑色虚线代表矩阵乘法计算(FLOPs)的耗时,所以任何高于这条线的曲线都意味着通信受限(comms-bound)。我们注意到,在批次大小约低于 6e5 时,所有策略都会受限于通信,这符合我们的预期:4096 * 2550^2 / (2 * 8192 * 4) ≈ 4e5。
图: 采用不同并行方案时的通信耗时。黑色虚线代表矩阵乘法计算(FLOPs)的耗时,所以任何高于这条线的曲线都意味着通信受限(comms-bound)。我们注意到,在批次大小约低于 6e5 时,所有策略都会受限于通信,这符合我们的预期:4096 * 2550^2 / (2 * 8192 * 4) ≈ 4e5。

黑色曲线是在模型 FLOPs 上花费的时间,这意味着只要某批次大小时该黑线低于所有通信成本线,它就会严格受限于通信。你会注意到黑色曲线和绿色曲线大约相交在 4e5 处,正如我们所预测的。

这里有一个交互式动画可以让你操作感受,展示了不同批次大小时的总计算时间和通信时间:

你会注意到,上述规律总体上吻合(即最低点在 FSDP=256, TP=16 附近),只会有一些小浮动,这是因为不同策略的并行轴数量略有不同。

流水线并行 (Pipelining)

你可能已经注意到,我们在前面的章节中完全没有提到流水线并行(Pipelining)。虽然流水线在 GPU 并行策略中占据主导地位,但在 TPU 上却并非不可或缺。简单来说,流水线训练是指将模型的不同层分布在多台设备上,并在前向和反向传播期间,在各个流水线阶段(pipeline stages)之间传递激活值。其算法大致如下:

  1. 在 TPU 0 上初始化数据,权重在层维度上进行了切片(使用 FSDP 和张量并行的流水线中为 $W_\text{in}[L_Z, D_X, F_Y]$)。
  2. 在 TPU 0 上执行第一层,然后将得到的激活值复制给 TPU 1,如此重复,直至到达最后一个 TPU。
  3. 计算损失函数及其导数 $\partial L / \partial x_L$。
  4. 对于最后一个流水线阶段,计算导数 $\partial L / \partial W_L$ 和 $\partial L / \partial x_{L-1}$,然后将 $\partial L / \partial x_{L-1}$ 复制回上一个流水线阶段,重复该过程,直到退回 TPU 0。
这是(可运行的) Python 伪代码 这段伪代码可以在 Cloud TPU VM 上运行。虽然它既不够高效也不太实际,但能让你感受到数据是如何在多台设备间传播的。
batch_size = 32
d_model = 128
d_ff = 4 * d_model

num_layers = len(jax.devices())

key = jax.random.PRNGKey(0)

# 假装每一层只是一个单独的矩阵乘法
x = jax.random.normal(key, (batch_size, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))

def layer_fn(x, weight):
  return x @ weight

# 假设 num_layers == num_pipeline_stages
intermediates = [x]
for i in range(num_layers):
  x = layer_fn(x, weights[i])
  intermediates.append(x)

  if i != num_layers - 1:
    x = jax.device_put(x, jax.devices()[i+1])

def loss_fn(batch):
  return jnp.mean(batch ** 2)  # 随便编造个假的损失函数

loss, dx = jax.value_and_grad(loss_fn)(x)

for i in range(num_layers - 1, -1, -1):
  _, f_vjp = jax.vjp(layer_fn, intermediates[i], weights[i])
  dx, dw = f_vjp(dx)  # 计算 jvp dx @ J(L)(x[i], W[i])
  weights[i] = weights[i] - 0.01 * dw  # 更新权重

  if i != 0:
    dx = jax.device_put(dx, jax.devices()[i-1])

为什么这是一种好方法? 流水线并行有许多显著优势:各个流水线阶段之间的通信成本极低,这意味着即使互连带宽较差,你依然能够训练极其庞大的模型。这在 GPU 上通常非常实用,因为 GPU 并不像 TPU 那样拥有高密度的 ICI(芯片间互连)网络。

为什么它会让人觉得棘手或繁琐? 你可能在上面的伪代码中已经注意到了:TPU 0 几乎一直处于闲置状态!它只在流水线的第一步和最后一步才进行计算。这段空闲期被称为“流水线气泡(pipeline bubble)”,是个非常令人头疼的问题。通常,我们首先尝试的缓解方法是引入微批次(microbatching),即向流水线中连续发送多个小批次,使得 TPU 0 至少在每个完整的训练步(step)内能有更长时间保持利用率。

第二种方法是,巧妙地将前向计算 $W_i @ x_i$、反向的 $dx$ 计算 $W_i @ \partial L / \partial x_{i+1}$ 以及 $dW$ 计算 $\partial L / \partial x_{i+1} @ x_i$ 交叠执行。由于每个操作都需要一定的 FLOPs,我们可以通过将它们在时间上重叠来彻底隐藏这个气泡。这是最近 DeepSeek v3 论文中展示其“无气泡(bubble-free)”流水线调度的一张示意图:

<b>图:</b> DeepSeek v3 的流水线调度 (来自他们 <a href="https://arxiv.org/pdf/2412.19437">最近的论文</a>)。橙色是前向计算,绿色是 dL/dx 计算,蓝色是 dL/dW 计算。通过优先执行反向传播中的 dL/dx 计算,我们可以避免算力闲置搁浅(stranded FLOPs)。
图: DeepSeek v3 的流水线调度 (来自他们 最近的论文)。橙色是前向计算,绿色是 dL/dx 计算,蓝色是 dL/dW 计算。通过优先执行反向传播中的 dL/dx 计算,我们可以避免算力闲置搁浅(stranded FLOPs)。

由于流水线在 TPU(通常通过互连网络组成更大规模的 pod)上相对没那么关键,我们不会对此进行过于深入的探讨,但理解流水线的关键瓶颈依然是一个不错的练习。

跨 Pod 扩展 (Scaling Across Pods)

可能的最大 TPU 切片是一个包含 8960 个芯片(和 2240 台主机)的 TPU v5p SuperPod。如果我们想扩展到这个规模之外,我们需要跨越数据中心网络(DCN,Data-Center Networking)边界。每台 TPU 主机都配备有一个或多个 NIC(网络接口卡),通过以太网将主机连接到其他 TPU v5p pod。正如在TPU 部分所指出的,每台主机拥有约 200Gbps (25GB/s) 全双工的 DCN 带宽,换算下来相当于每个 TPU 约 6.25GB/s 的全双工(出口)带宽。

通常,在跨越单 pod 扩展时,我们在 ICI 域内进行某种形式的模型并行或 FSDP,然后跨多个 pod 执行纯数据并行。设 $N$ 为我们想要扩展的 TPU 数量,并设 $M$ 为每个 ICI 互联切片内的 TPU 数量。为了在 DCN 上执行 AllReduce,我们可以在多个 pod 组成的集合上执行环状归约(ring-reduction),这样便给出了(在反向传播期间的)公式:

$$T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot BDF}{N \cdot C}$$ $$T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot DF}{M \cdot W_\text{dcn}}$$

通信带宽与 $M$ 成正比,因为不同于 ICI,这里随着我们扩展 ICI 域并获得更多的 NIC,总带宽是会增加的。简化后我们发现,当以下条件成立时,$T_\text{math} > T_\text{comms}$:

$$\frac{B}{\text{slice}} > \frac{C}{W_\text{dcn}}$$

对于 TPU v5p,此处的 $\frac{C}{W_\text{dcn}}$ 约为 4.46e14 / 6.25e9 = 71,360。这告诉我们,为了有效地跨 DCN 进行扩展,每个 ICI 域内的批次大小存在一个要从各个节点传出的最小值限制。

这是一个多严重的问题呢? 举一个具体的例子,假设我们希望用 2M tokens 的 BS(批次大小)在 TPU v5p 上训练 LLaMA-3 70B 模型。LLaMA-3 70B 有 $F\approx 30,000$。由前几节,我们知道以下几点:

长话短说,如果使用 BS=1M,我们可以使用大约 X (FSDP) = 1024 且 Y (TP) = 8 的理想方案;但对于 BS=2M 的训练,我们就需要使用 DCN。如上所述,我们的 DCN 算术强度是 $\text{71,360}$,所以我们只需要确保我们每个 ICI 域的批次大小大于这个值。这对我们来说简直轻而易举,因为如果是 2 个 pod,我们每个 pod 的 BS 已经是 1M 了,分配给每个 TPU 的批次大小为 111,非常棒(也许稍微有点接近极限,但在理论上是行得通的)。

**关键结论:**只要每个 pod 的批次大小达到至少 71k(71,000)tokens,就可以用纯数据并行较为直接地扩展到多个 TPU pod。

在 TPU 上训练 LLM 的关键结论

策略 描述
数据并行 (Data Parallelism) 激活值沿批次维度切片,其余所有内容完全复制。在反向传播期间,我们通过 all-reduce 同步梯度。
FSDP 激活值、权重和优化器状态均沿批次维度切片,权重在即将使用时才被 Gather 收集(just-in-time),梯度则通过 reduce-scatter 分散。
张量并行 (Tensor Parallelism) 激活值沿 $$d_\text{model}$$ 切片,权重沿 $$d_{ff}$$ 切片;在乘以 Win 之前对激活值执行 gather 收集,其计算结果在乘以 Wout 后通过 reduce-scatter 分散。
混合 FSDP + 张量并行 (Mixed FSDP + Tensor Parallelism) 结合了上述两者,其中 FSDP 对已经历模型并行切分的权重再执行 gather 操作。

下面是每个方法的“推导公式”:

$$\small \begin{array}{cc} \text{策略} & \text{推导公式}\\ \hline \text{DP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D] \\ \text{FSDP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D] \\ \text{TP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \\ \text{TP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \\ \hline \end{array}$$ $$ \small \begin{array}{ccc} \text{策略} & \text{每层计算量} & \text{每层通信量} \\ & \text{(忽略门控 einsum 的计算)} & \text{(字节, 包括前向+反向传播)}\\ \hline \text{DP} & 4BDF/X + 8BDF/X & 0 + 8DF \\ \text{FSDP} & 4BDF/X + 8BDF/X & 4DF + 8DF \\ \text{TP} & 4BDF/Y + 8BDF/Y & 4BD + 4BD \\ \text{FSDP + TP} & 4BDF/(XY) + 8BDF/(XY) & (4BD/X + 4DF/Y) + (8BD/X + 8DF/Y) \\ \hline \end{array}$$

一些练习题

我们以 LLaMA-2 13B 作为本节基础探讨的模型。以下是模型参数:

超参数 (hyperparam) 数值 (value)
L 40
D 5,120
F 13824
N 40
K 40
H 128
V 32,000

LLaMA-2 有独立的嵌入层(embedding)、输出矩阵(output matrix)以及带有门控机制的 MLP 模块。

问题 1: LLaMA-2 13B 有多少参数?(我知道这问题看起来有些傻,但试着算算看)注意,正如在第 4 部分中提到的,LLaMA-3 的前馈网络(FFW)包含 3 个大矩阵:两个升维(up-projection)矩阵和一个降维(down-projection)矩阵。在本章中,我们省略了执行门控的矩阵乘法(einsum),但对于本节的探讨,它的行为与 Win 完全相同。

点击这里查看答案。 * FFW 参数量: $$3LDF$$ = `8.5e9` * 注意力机制参数量: $$4DNHL$$ = `4.2e9` * 词表参数量: $$2VD$$ = `0.33e9` * 总量: `8.5e9 + 4.2e9 + 0.33e9 = 13.0e9`,完全吻合!

问题 2: 假设我们使用 Adam 优化器在总批次大小 BS=16M 下进行训练。在不考虑并行的情况下,模型参数、优化器状态和激活值需要占用多少内存?(假设使用 bf16 存储参数,fp32 存储优化器状态;并假设我们在每层完成 3 个主要的繁重矩阵乘法后对激活值进行检查点重计算(checkpointing activations))。

点击这里查看答案。 参数(bf16)和两个优化器状态(fp32 的一阶和二阶动量累加器)的内存消耗为 `(2 + 4 + 4) * 13e9 ~ 130GB`。前两个矩阵乘法后的激活值大小为 $BF$,最后一个矩阵乘法后的大小为 $BD$(参见上文的 Transformer 示意图),因此在 bf16 精度下,内存总共是 $2 \cdot L \cdot (BD + 2 * BF) = 2LB \cdot (D + 2F)$,即 `2 * 40 * 16e6 * 5,120 * (1 + 2 * 2.7) ~ 4.2e13 = 42TB`(此时 `B=16e6`)。除了这些,所有其他的激活值大小基本上都可以忽略不计。

问题 3: 假设我们正在一个 16x16x16 的 TPUv5p 切片上训练,序列长度为 32k,总批次大小为 3M tokens。继续假设使用 bfloat16 权重和 float32 优化器。

  1. 我们可以使用纯数据并行吗?为什么可以或为什么不可以?
  2. 我们可以单独使用 FSDP 吗?为什么?如果我们单独使用 FSDP,每台设备需要多少内存(假设我们只在每层 3 个繁重的矩阵乘法后检查点重计算激活值)?
  3. 我们可以混合使用 FSDP + 张量并行(TP)吗?为什么可以或不可以?如果可以,我们应该设置多大的 $X$ 和 $Y$?每台 TPU 的内存使用量是多少?仅基于屋顶线(roofline)计算 FLOPs 且忽略注意力机制,假设 40% 的 MFU,每步训练需要多长时间?
点击这里查看答案。 首先理清一下数字。在序列长度为 32k 时,3M tokens 的批次意味着序列批次大小为 96。在一个 16x16x16 的 TPU v5p 切片中,我们总共有 `393TB` 的 HBM。 1. 我们不能使用纯数据并行。纯数据并行意味着我们将在每个芯片上拥有所有的参数和优化器状态。从问题 2 可知这大约是 130GB,远超单芯片可用的 96GB HBM。 2. 让我们先来看内存。将问题 2 中的 BS=16M 替换为 3M,激活检查点占用大约 `~7.86e12` 字节,加上 1.3e11 字节的优化器状态,总计几乎正好是 8e12 = 8TB。TPUv5p 切片总共有 `393TB` 的 HBM,所以我们远没有达到 HBM 限制。接下来看我们是受限于通信还是计算。拥有 4096 个芯片和 3 个并行轴,我们支持的最小批次大小是 `850 * 4096 = 3.48M` tokens。这比我们的 3M 批次大小略大。因此我们实际上受到了通信瓶颈的限制,这很遗憾。所以总的答案是:**不,我们不能只单独使用 FSDP**。 3. 现在我们知道主要问题是通信受限,让我们代入一些数字。首先,从上文我们知道,采用混合 FSDP + 张量并行时,我们的单设备批次大小必须在 $2550^2 / 2F = 235$ 之上。这意味着理论上这是可行的!让我们计算一下各自需要多少。 根据公式 $X_{opt} = \sqrt{(B / F) \cdot (M_X / M_Y) \cdot N}$,这里我们得到 `sqrt(3e6 * 2 * 4096 / 13824) = 1333`,意味着我们将采用大约 1024 路 DP 和 4 路 TP。每台 TPU 的内存如问题 (2) 中所述,单步耗时将是 `6 * 3e6 * 13e9 / (4096 * 4.6e14 * 0.4) = 300ms`。

这就是第 5 部分的全部内容!点击进入[第 6 部分](../applied-training),我们将把这些知识应用到真正的 LLaMA 模型中!

附录

附录 A:推导反向传播的通信量

在前面的章节中,我们将 Transformer 层的前向传播简化为 Out[B, D] = In[B, D] D Win[D, F] F Wout[F, D]。我们是如何推导出反向传播所需的通信量的呢?

这很自然地遵循了前一节中关于单一矩阵乘法 Y = X * A 的推导规则:

$$\frac{dL}{dA} = \frac{dL}{dY}\frac{dY}{dA} = X^T \left(\frac{dL}{dY}\right)$$ $$\frac{dL}{dX} = \frac{dL}{dY}\frac{dY}{dX} = \left(\frac{dL}{dY}\right) A^T$$

通过使用这条推论,我们可以得出以下的公式体系(令 Tmp[B, F] 代表 In[B, D] * Win[D, F]):

1. dWout[F, D] = Tmp[B, F] *B dOut[B, D] 2. dTmp[B, F] = dOut[B, D] *D Wout[F, D] 3. dWin[D, F] = In[B, D] *B dTmp[B, F] 4. dIn[B, D] = dTmp[B, F] *F Win[D, F]

值得注意的是,这些公式仅仅描述了数学上的等式,并没有涉及任何“切片”(sharding)!反向传播计算的目标仅仅是为了推导出这四项数据。为了估算此运算中需要牵涉的网络通信量,你只需在执行上述四步的每一项运算前,根据你在前向传播中设计的切分模式对相应的分块(Tmp, dOut, Wout, Win)进行切片,并解出需要什么样的通信操作即可。请注意,其中的 dOut 是按照与 Out 完全相同的布局被切片的。

第 6 章

在 TPU 上训练 LLaMA 3

本节的目标,是把上一节的结果应用到一个非常实际的问题上:训练 LLaMA 3 系列(herd)模型。与前几节不同,我们希望你亲自动手完成其中许多计算。出于这个原因,我们把每一节的答案都折叠隐藏了,这样你可以先自己尝试解答。试着拿起纸笔,手算一遍吧!

LLaMA 3 是什么样子的?

LLaMA-3 模型家族包含 3 个主要模型:LLaMA 3 8B、70B 和 405B。我们主要聚焦于 70B,把 8B 和 405B 留给你在文末的习题部分自行探索。下面是 LLaMA 3-70B 的架构,摘自 LLaMA 的 HuggingFace 页面

超参数 取值
$$n_\text{layers}$$ (L) 80
$$d_\text{model}$$ (D) 8,192
$$d_{ff}$$ (F) 28,672
$$n_\text{heads}$$ (N) 64
$$n_\text{kv_heads}$$ (K) 8
$$d_\text{qkv}$$ (H) 128
$$n_\text{embeddings}$$ (V) 128,256

为了说明这些信息有多容易获取,下面直接给出配置文件本身,以及相应的参数映射:

assets/img/llama-json.png

为许多不同的开源 LLM 建一张包含这些数字的大表会很有用,这样你就能快速比较它们在设计决策上的差异。

参数量和 FLOPs 计数

问题: 从这张表里,我们能算出 LLaMA 3-70B 的参数量吗?🤫 让我们应用第 4 节的内容,看看能不能得到 70B!

参数 公式 数量
FFW 参数 d_model * d_ff * 3(对应 SwiGLU 的 gate、up 和 down 三个投影)* n_layers 8,192 * 8,192 * 3.5 * 3 * 80 = 56.3e9
词表参数 2(输入和输出嵌入)* n_embeddings * d_model 2 * 128,256 * 8,192 = 2.1e9
注意力参数 n_layers * [ 2(q 嵌入和拼接后的输出投影) d_model * n_heads * d_qkv + 2(k 和 v) d_model * n_kv_heads * d_qkv] 80 * (2 * 8,192 * 64 * 128 + 2 * 8,192 * 8 * 128) = 12e9
56.3e9 + 2.1e9 + 12e9 = 70.4e9

很好!我们得到了符合预期的数字。正如预料的那样,FFW 参数在总参数量中占据了绝对主导地位,尽管注意力部分也并非微不足道。

**要点**:MLP 模块中的 3 个大权重矩阵,远大于 Transformer 中所有其他数组,因此在推理模型内存或 FLOPs 时,我们通常几乎可以忽略其他所有参数。对于 LLaMA 3-70B,这 3 个矩阵占了 70B 参数中的 56B。

现在我们来看看 FLOPs!记住第 4 节中关于训练的一般规则。

问题: LLaMA-3 在每个训练 step 中、每个 token 会执行多少 FLOPs?这能帮助我们判断整个训练过程有多昂贵。

想好之后,点击这里查看答案! **答案**:如[第 4 节](../transformers)所示,我们每个 token 大致会执行 $$6 \cdot \text{param count}$$ 次 FLOPs,因此这里大约是 `6 * 70e9 = 4.2e11` FLOPs / token。也就是每个 step、每个 token 约半个 TFLOP。假设我们受算力限制,那么在单个 TPU v5p 芯片上,这大约需要 `4.2e11 / 4.59E+14 = 1ms`,这里假设 FLOPs 利用率完美。

问题: LLaMA 3 大约使用了 15 万亿 token 进行训练。那么总共是多少 FLOPs?

想好之后,点击这里查看答案! **答案**:这很简单,直接计算 `4.2e11 * 15e12 = 6.3e24 FLOPs`。总共是 6.3 尧次 FLOPs(yottaFLOPs)。很多!如果只用单个 TPU,这会花费 `6.3e24 / 4.59E+14 = 435 年`。这同样非常多!

问题: 假设我们想在一个完整的 TPU v5p pod 上训练,它有 16x20x28 = 8960 个芯片。在 bfloat16 下,假设 MFU 为 40%,并且训练受算力限制,训练需要多久?

想好之后,点击这里查看答案! **答案**:我们知道每个 TPU v5p 每秒能执行 4.59e14 FLOPs。在 40% MFU 下,总耗时约为 `T = 6.3e24 / (8960 * 4.59e14 * 0.4) = 3.8e6 seconds`。**也就是大约 44 天!** 这相当合理,前提是我们真的能达到 40% 的 MFU。

问题: LLaMA 3-70B 预训练时的 batch size 大约是 4M token。要用这个 batch size 进行训练,我们至少需要多少个 TPU?你可以假设参数使用 bfloat16,优化器状态使用 float32,并且每层做 4 次梯度检查点。

想好之后,点击这里查看答案! **答案**:这个问题主要是在问内存占用,因为那是可用算力唯一严格的约束。训练期间,HBM 主要有三类用途:模型参数、优化器状态和梯度检查点。假设权重是 bfloat16、优化器状态是 float32,并采用一种_非常_保守的梯度检查点方案(每层 4 次),则有: | **参数** | 2 * 70GB | ~140GB | | **优化器状态** | 8 * 70GB | ~560GB | | **梯度检查点** | 2 * 8192 * 4e6 * 4 * 80 | ~20.9TB | | **总计** | | ~21.6TB | 总量约为 21.6TB。你会注意到,即便采用非常保守的检查点方案,梯度检查点依然在内存占用中占据绝对主导。技术上我们可以降到每层 1 个检查点,或者采用 microbatching,但这个估算已经足够合理。在这些假设下,由于每个 TPU v5p 拥有 96GB HBM,我们需要 `21.6e12 / 96e9 = 225` 个 TPU。其实这并不算多! *为什么我们不这么做?* 因为训练时间会变成 `44 days * 8960 / 225 = 1752 days`。那几乎是四年。**太久了。** 不过,这也清楚表明,我们使用这些大规模集群并不是因为受内存限制,而是因为我们需要额外的 FLOPs。

问题: 在与上一题相同的假设下,如果我们使用 8960 个 TPU v5p 芯片,每个芯片会占用多少内存?

想好之后,点击这里查看答案! **答案**:总内存仍然约为 21.6TB,因此每个芯片大约占用 2.4GB,几乎可以忽略不计。即便我们采用更激进得多的检查点方案,例如每层 12 个检查点,每个芯片也不过只会占用 8GB。在这种规模下,训练时我们离内存瓶颈还很远。

**要点**:从技术上讲,即便在非常小的拓扑上训练超大模型也是可行的,前提是你愿意接受它可能需要很长时间。只要能够计算一次训练运行的总 FLOPs,我们就可以假设一个适中的 MFU 和已知拓扑,粗略估算出训练时长。

如何为训练切分 LLaMA 3-70B

让我们继续沿用上面的设定:在一个 8960 芯片的 TPU v5p pod 上,用 4M token 的 batch size(每个 batch 为 1024 个长度 4096 的序列)训练 LLaMA 3-70B。下面来讨论,这个模型的最佳切分策略是什么。

问题: 在上述假设下,我们能否仅使用 FSDP 来训练模型?首先,假设我们不能做任何序列/上下文并行。这应该是你首先想到的方案,因为它很简单,而且如果可行的话不会引入额外通信。

想好之后,点击这里查看答案! **答案**:这个回答会稍微有点咬文嚼字。正如上面提到的,LLaMA 3-70B 最初训练时使用的是长度为 4K 的序列,因此 4M token 的 batch size 对应的*序列 batch size*是 1024。这意味着,我们实际上最多只能在 1024 个芯片上做纯数据并行/FSDP,_因为我们只有这么多序列可以拿来做数据并行_。因此,如果按“纯数据并行且不增加额外通信”的朴素理解,答案是否定的。下一个问题会回答一个稍微不那么咬文嚼字的版本。

问题: 现在放宽“不做任何序列切分”的要求。如果我们允许在 batch 轴和 sequence 轴上同时做 FSDP,那么能否只用 FSDP 就在 8960 个芯片上训练 LLaMA 3-70B?

想好之后,点击这里查看答案! **答案**:既然现在允许我们也做序列/上下文并行,就能扩展到更大的规模。首先计算每个设备上的 batch size。如果做 8960 路 FSDP,那么每个 TPU 上的 batch size 为 `4 * 1024 * 1024 / 8960 = 468 tokens`。我们从上一节知道,当 $$\text{per device batch size} < 2550 / M_X$$ 时,FSDP 会受到 ICI 通信瓶颈限制。由于这里在完整 3D pod 上可以使用 3 个轴,因此下界是 850,而我们的数值远低于它。**所以答案仍然是否定的,即使使用 3 个轴也是如此。我们会明显受到通信瓶颈限制。**

问题: 现在来看看混合张量并行和 FSDP。是否存在某种组合可以让我们仍然保持算力受限?如果有,应该采用多大规模的 FSDP 和张量并行?

想好之后,点击这里查看答案! **答案**:首先检查这种方案是否能放得下。我们知道,当每芯片 batch size 小于 $2550^2 / 2F = 113$ 时,我们会受到通信瓶颈限制。正如上面看到的,我们略高于这个阈值。很好!接下来为了选出最优的 FSDP 规模,我们可以使用公式 $$X_{opt} = \sqrt{\frac{2BN}{F}} = \sqrt{\frac{2 \cdot 4.19e6 \cdot 8960}{28672}} = 1618$$ 把它四舍五入到一个合理的 2 的倍数后,大约得到 2048 路 FSDP 和 4 路张量并行。这个方案应该表现不错!

**要点**:我们可以在一个完整的 TPU v5p pod 上,用 4M token 的 batch size,通过混合数据并行(1024 路)、序列并行(2 路)和张量并行(4 路)来训练 LLaMA-3,同时不受通信瓶颈限制。如果尝试纯 FSDP,或 FSDP 加序列并行,我们就会受到通信瓶颈限制。我们在上一节推导出来的那些公式非常实用。

练习题

问题 1【将 LLaMA 70B 扩展到更多芯片】: 假设我们想在 4 个 pod 上、保持相同 batch size 训练 LLaMA 3-70B。我们应该使用什么并行方案?会受算力瓶颈还是通信瓶颈限制?大致需要多久才能完成训练?请务必使用正确的 roofline 上界。

问题 2【LLaMA 405B】:

(a) 使用 LLaMA 3-405B 的 config,像上面一样写出一张包含所有关键超参数的表。这个模型总共有多少参数?每个训练 step 会执行多少 FLOPs?如果训练 15T token,总共会执行多少 FLOPs?

(b) 假设我们想在 8 个 TPU v5p pod 上训练。我们应该使用什么并行方案?训练需要多久?会受算力瓶颈还是通信瓶颈限制?

第 6 节到这里就结束了。点击[这里](../inference)进入第 7 节,了解 Transformer 推理。

第 7 章

Transformer 推理全解析

Transformer 推理基础

假设你已经训练好了一个 Transformer,现在你想用它生成一些新的序列。归根结底,benchmark 分数上升、loss 曲线下降,都只是“真正上路之后会不会发生有趣的事情”的代理指标而已!从历史上看,即使完全不碰推理,你也能在 Transformer 上做出相当多的研究工作。基于打分的多项选择 benchmark,在没有像样的 KV cache 或生成循环实现的情况下,也能高效跑起来。这意味着,尤其是在研究型代码库里,推理路径里常常有大量唾手可得的优化空间。

采样在概念上很简单。我们把一个序列输入进去,喜爱的 Transformer 就会吐出 $$\log p(\text{next token}_i \vert \text{previous tokens})$$,也就是所有可能下一个 token 的对数概率。我们从这个分布中采样,得到一个新 token。把它追加到序列后面,然后重复这个过程,就能得到一段作为提示词续写的 token 序列。

<b>图:</b>Transformer 的朴素采样。蓝色 logits 给出了下一个 token 的概率分布,我们可以从中采样。注意,每一步都会重新处理整个前缀,因此该算法的运行时间是 $\Theta(n^2)$。
图:Transformer 的朴素采样。蓝色 logits 给出了下一个 token 的概率分布,我们可以从中采样。注意,每一步都会重新处理整个前缀,因此该算法的运行时间是 $\Theta(n^2)$。

我们刚刚描述的是 Transformer 采样的朴素实现。虽然它能工作,但实践中我们从来不会这么做,因为每生成一个 token,我们都要重新处理整段序列。这个算法在 FFW 上生成 $$n$$ 个 token 的复杂度是 $$O(n^2)$$,在 attention 机制上的复杂度则是 $$O(n^3)$$!

那我们如何避免这一点? 与其每次都重新做完整前向传播,不如把每次前向传播中的一些中间激活保存起来,这样就能避免重复处理之前的 token。具体来说,在点积注意力中,一个 token 只会关注它之前的 token,因此我们只需把每个 token 的 key 和 value 投影写入一种新的数据结构,即 KV cache。一旦把过往 token 的 key/value 投影保存下来,未来 token 就可以直接计算它们的 $$q_i \cdot k_j$$,而不需要再对更早的 token 做任何新的 FLOPs。很神奇吧!

从这个角度看,推理包含两个关键部分:

下面是带 KV cache 的采样示意图:

<b>图:</b>借助 KV cache 的高效 Transformer 采样示意图。<b style="color: red;">Prefill</b> 会处理 prompt,并把每个 token 的 key-value 激活都保存到缓存中。<b style="color: blue;">Generation</b> 则接收该缓存(以及最后一个 token 的 logits),采样出一个新 token,再将它送入模型,在访问 KV cache 的同时把这个新 token 的 key-value 投影写回缓存。在 MLP block 中,这是一个 $O(n)$ 算法。
图:借助 KV cache 的高效 Transformer 采样示意图。Prefill 会处理 prompt,并把每个 token 的 key-value 激活都保存到缓存中。Generation 则接收该缓存(以及最后一个 token 的 logits),采样出一个新 token,再将它送入模型,在访问 KV cache 的同时把这个新 token 的 key-value 投影写回缓存。在 MLP block 中,这是一个 $O(n)$ 算法。

通过使用 KV cache 进行采样,我们把生成 $n$ 个 token 的时间复杂度降到了:FFW 为 $$O(n)$$,attention 为 $$O(n^2)$$,因为我们再也不需要重新处理先前的 token。不过,要生成一段序列仍然需要进行很多次前向传播。这也就是你在使用 Gemini 或 ChatGPT 时,为什么结果会一边生成一边流式返回。每个 token(通常)都是对一个巨大模型的一次独立 Transformer 调用,只不过其中一部分被缓存了。

我们很快就会看到,prefillgeneration 是两种非常不同的野兽,Transformer 推理其实伪装成了两项任务!相比训练,KV cache 也是一个全新且重要的复杂性来源。

我们真正想优化的是什么?

在继续之前,有必要先强调一个推理中特别全新的因素:延迟。训练时我们只关心吞吐(每个芯片每秒处理的总 token 数),但在推理时,我们必须关心生成 token 的速度,包括 首 token 时间(Time To First Token, TTFT)逐 token 延迟(per-token latency)。例如:

最大化硬件利用率依然至关重要,也有助于降低成本和 TTFT;但与训练不同的是,在所有场景里,它并不必然转化为更好的单用户体验。加速器层面、系统层面和模型架构层面的许多优化,都会在延迟、吞吐、上下文长度,甚至模型质量之间进行权衡。

更细粒度地看待 Transformer

到目前为止,我们大多把 Transformer 视为一堆前馈 block 的堆叠。从 FLOPs 和内存角度看,这通常没问题,但要正确建模推理,这还远远不够。你会在本节反复注意到一点:推理比训练要“娇气”得多。我们通常拥有更少的 FLOPs、更少的批处理机会,而且对延迟更加敏感。KV cache 也让推理复杂度显著增加。正如我们在第 4 部分中看到的,Transformer 前向传播的主要组成部分包括:

  1. 大量线性操作,包括 MLP($W_{in}$、$W_{out}$)以及 attention 中的 QKV 投影与输出投影($W_Q$、$W_K$、$W_V$ 和 $W_O$)。这些操作都需要从 HBM 中读入参数和一批激活,做若干 FLOPs,再把结果写回 HBM。
  2. 点积注意力。我们需要从 HBM 中读取一批 key-value 投影和一批 query 激活,做一些内积和 softmax 运算,再把 attention 结果写回 HBM。
  3. 其他所有东西,包括 layer norm、激活函数、token 采样、更新 KV cache,以及位置编码等。这些也会消耗一些 FLOPs,但通常会被上面两类操作主导,或者被融合进去。

接下来几个小节中,我们会在 prefill 和 generation 的语境下分别观察这些组成部分,并问一个问题:性能最可能卡在哪里?在单个加速器内部,我们是算力受限(compute-bound)还是内存受限(memory-bound)?我们想强调的是,prefill 和 generation 对这个问题的答案会有多么不同。

线性操作:我们的瓶颈是什么?

所有线性操作在概念上都一样,不管它们是在 MLP block 里,还是在 attention 里。它们的算术强度取决于 batch size。我们在第 1 节已经算过一遍,不过值得在这里重述。来看一个单独的矩阵乘法:把一个 $\text{bf16[B, D]}$ batch 与一个 $\text{bf16[D, F]}$ 矩阵相乘。这可以是大的 MLP block($W_\text{in}$ 或 $W_\text{out}$),也可以是较小的 attention 投影之一($W_Q$、$W_K$、$W_V$、$W_O$)。为了做这个 matmul,我们需要把这两个数组从 HBM 读到 MXU 中,完成乘法,然后再把结果写回 HBM。和前文一样,我们有:

$$T_\text{math} = \frac{\text{Computation FLOPs}}{\text{Accelerator FLOPs/s}} = \frac{2BDF}{\text{Accelerator FLOPs/s}}$$ $$T_\text{comms} = \frac{\text{Communication Bytes}}{\text{Bandwidth Bytes/s}} = \frac{2BD + 2FD + 2BF}{\text{Bandwidth Bytes/s}}$$

TPU 或 GPU 可以一边加载一边计算,也就是把二者重叠起来。所以,要成为 compute-bound,我们需要 $$T_\text{math} \geq T_\text{comms}$$,也即:

$$\frac{2BDF}{2BD + 2DF + 2BF} \geq \frac{\text{Accelerator FLOPs/s}}{\text{Bandwidth Bytes/s}} \underset{\text{TPU v5e}}{=} \frac{1.97E+14}{8.20E+11} = 240$$

其中右边就是硬件的算术强度。现在假设 $D$ 和 $F$ 相比 $B$ 都非常大(通常 batch 最多也就 500,而 $D$ 和 $F > 10k$),那么可利用 $\small{2BD + 2DF + 2BF \approx 2DF}$ 来简化分母,得到

$$\begin{align*} \frac{2BDF}{2BD + 2DF + 2BF} \approx \frac{2BDF}{2DF} \geq \frac{\text{Accelerator FLOPs/s}}{\text{Bandwidth Bytes/s}} \\ \underset{\text{TPU v5e}}{=} \frac{1.97E+14}{8.20E+11} \implies B \geq 240 = B_{\text{crit}} \end{align*}$$

如果我们对权重做量化,或在线性代数计算中使用更低精度 FLOPs,那么这个临界 batch size 会变化。例如,如果把权重量化到 int8 或 fp8,$B_\text{crit}$ 会降低 2 倍;如果把 FLOPs 计算也改成 int8 或 fp8,$B_\text{crit}$ 则会上升 2 倍。因此,如果令 $\beta = \text{bits per param} / \text{bits per activation}$,$\alpha_\text{hbm} = C / W_\text{hbm}$,则临界 batch size 实际上为 $B_\text{crit} = \beta \alpha_\text{hbm}$。

**要点:** 当且仅当每个副本上的 **token** batch size 大于 $B_\text{crit} = C / W_\text{hbm} \cdot (\text{bits per param} / \text{bits per activation}) = \beta \cdot \alpha_\text{hbm}$ 时,Transformer 的 matmul 才是 compute-bound。对于 TPU v5e 上的 bf16 激活,这个值是 240 token;对于 H100,大约是 280 token。

训练期间,我们在所有矩阵乘法中都会有很高的算术强度,因为相同的一组权重会在很大的 batch 上被反复复用。这种高算术强度会延续到 prefill 中,因为用户 prompt 通常都有几百甚至几千个 token。 如前所述,TPUv5e 的硬件算术强度是 240,因此如果在这类硬件上用 bf16 运行一个稠密模型,并输入长度超过 240 的序列,我们就预期自己会是 compute-bound,一切都很理想。更短的 prompt 在技术上也可以合并 batching 以提高利用率,但通常没这个必要。

**要点:** 在 prefill 阶段,几乎所有矩阵乘法基本总是 compute-bound。因此,只要最大化硬件利用率或 MFU(Model FLOPs Utilization),就足以同时最大化单芯片吞吐(成本)和延迟(以 TTFT 形式体现)。除非 prompt 极短,否则按 prompt 维度做 batching 往往只会增加延迟,而对 prefill 吞吐的改善很有限。

然而,在 generation 阶段,对于每个请求,我们一次只能做一个 token 的前向传播,因为步骤之间存在串行依赖!因此,想要(相对容易地)获得较好的利用率,只能把多个请求批在一起,沿 batch 维度并行。稍后我们还会展开讨论,但要在不伤害延迟的前提下把大量并发请求真正 batch 到一起,其实很难。因此,generation 更难把硬件 FLOPs 跑满。

**要点:** 在 generation 阶段,总 token batch size 必须大于 $B_{\text{crit}}$,线性 / 前馈操作才会成为 compute-bound(对 TPU v5e 上的 bf16 参数而言是 240)。由于 generation 是逐 token 串行进行的,这意味着我们必须把多个请求批在一起,而这并不容易!

值得注意的是,这个数真的很大! 生成 batch size 为 240,意味着要同时有 240 个并发请求在生成,并且对稠密模型而言,还要有 240 份独立的 KV cache。所以在实践里,这通常很难做到,除了一些大规模离线推理场景。相比之下,在 prefill 时一次通过超过 240 个 token 是非常常见的,不过随着稀疏性增加,还是需要多加留意。

还要注意,这个具体数值会随着量化方式和硬件不同而变化。 加速器在更低精度下通常能提供更多算力。比如,如果参数是 int8、计算用 bf16,那么临界 batch size 会降到 120。若激活和参数都用 int8,则它又回到 240,因为 TPUv5e 可以提供 400 TOPs/s 的 int8 x int8 算力。

Attention 又如何?

当我们观察点积注意力操作时,情况会变得更复杂,尤其是还得把 KV cache 算进去。先只看一个注意力头,并假设是纯多头注意力。在单个 Flash Attention 融合核中,我们这里我们做了不少简化,忽略了 softmax、mask 等非 matmul FLOPs。它们应该与计算或 HBM 读取重叠,但在某些 TPU 世代上,这并不总是容易做到。尽管这些细节不会改变主要结论,即 KV cache 往往是 memory-bound,但它们依然值得关注。

  1. 从 HBM 读取形状为 $\text{bf16[B, T, D]}$ 的 $Q$ 激活。
  2. 读取 $KV$ cache,它是一对形状为 $\text{bf16[B, S, D]}$ 的张量。
  3. 在 $$QK$$ matmul 中执行 $2BSTD$ FLOPs。借助 Flash Attention,我们不需要把 $\text{bf16[B, S, T]}$ 的注意力矩阵写回 HBM。
  4. 在 attention 的 $$AV$$ matmul 中执行 $2BSTD$ FLOPs。
  5. 把得到的 $\text{bf16[B, T, D]}$ 张量写回 HBM。

把这些合起来,我们得到:

$$\text{Multiheaded Attention Arithmetic Intensity} = \frac{4BSTD}{4BSD + 4BTD} = \frac{ST}{S+T}$$

对于 prefill,由于我们做的是自注意力,所以 $S=T$,这就化简成 $T^2 / 2T = T / 2$。这很好,因为它意味着 prefill 阶段注意力的算术强度是 $\Theta(T)$。也就是说,只要序列长度足够大,attention 很容易就是 compute-bound!

但在 generation 阶段,序列维度几乎退化,而 $B$ 和 $D$ 维又会被抵消,因此我们可以近似写成:

$$S \gg T = 1 \implies \frac{ST}{S+T} \approx 1$$

这就糟糕了,因为它意味着我们无法提升 generation 阶段 attention 的算术强度。我们做的 FLOPs 非常少,却要加载巨大的 KV cache。因此,attention 在 generation 时基本总是受限于内存带宽!

**要点:** 在 prefill 阶段,只要序列长度稍微合理一些(大约 $\gt 480$ token),attention 通常就是 compute-bound;而在 generation 阶段,它的算术强度低且恒定,因此我们总是 memory bandwidth-bound。

从概念上看,为什么会这样? 主要原因是:在线性部分,我们之所以 compute-bound,是因为参数(那些消耗内存带宽的大头)会被许多 batch item 复用。但每个 batch item 都有自己独立的 KV cache,所以 batch size 越大,就意味着 KV cache 越多。除非你对架构做了相当激进的调整,否则这里几乎总是会受限于内存带宽。

这也意味着:一旦参数内存与 KV cache 内存处于同一量级,继续增大 batch size 所能带来的吞吐收益就会递减。收益递减有多严重,取决于单个序列上参数字节数与 KV cache 字节数之比,也就是大致的 $2DF / SHK$。由于 $HK\approx D$,它大致取决于 $F$ 与 $S$(序列长度)的比值。当然,这也取决于那些能让 KV cache 更小的架构改动(稍后会详细谈)。

LLM 延迟与吞吐的理论估算

通过上面的数学推导,我们可以对优化时应该争取的 step time 给出相当不错的界限。(注意:如果希望读者从整章里只记住一件事,那大概就是下面这个公式。) 对 generation 中常见的小 batch size 而言,我们可以假设 attention 和 MLP block 都受限于内存带宽,从而给每一步延迟一个下界:

$$\begin{equation*} \text{Theoretical Min Step Time} = \frac{\text{Batch Size} \times \text{KV Cache Size} + \text{Parameter Size}}{\text{Total Memory Bandwidth}} \end{equation*}$$

类似地,对吞吐而言:

$$\begin{equation*} \text{Theoretical Max Tokens/s} = \frac{\text{Batch Size} \times \text{Total Memory Bandwidth}}{\text{Batch Size} \times \text{KV Cache Size} + \text{Parameter Size}} \end{equation*}$$

最终,随着 batch size 增大,FLOPs 开始主导参数加载,于是更一般的实际公式变成:

$$\begin{align} \tiny \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\tiny \text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\tiny \text{MLP (can be compute-bound)}} \end{align}$$

其中 attention 这一项(左边)永远不是 compute-bound,因此不需要再套一个 FLOPs roofline。这个公式非常适合做拍脑袋量级估算,例如:

随堂小测: 假设我们想在 TPU v5e 4x4 slice 上,对一个 30B 参数的稠密模型执行一次 generate step,batch size 为 4 token,参数为 int8、FLOPs 为 bf16、上下文长度 8192,每个 token 的 KV cache 为 100 kB。这个操作的合理延迟下界是多少?如果我们想采样一个 256 token 的 batch,又会怎样?

点击这里查看答案。 **答案:** 在 int8 下,参数将占用 30e9 字节。按照给定配置,每份 KV cache 将使用 `100e3 * 8192 = 819MB`。我们有 16 个芯片,每个芯片带宽为 `8.1e11` bytes/s,bf16 FLOPs 为 `1.97e14`。根据上面的公式,由于 batch size 很小,我们预期 step time 至少是 `(4 * 819e6 + 30e9) / (16 * 8.1e11) = 2.5 ms`。而在 256 token 时,我们的 MLP block 会明显进入 compute-bound 区间,因此 step time 约为 `(256 * 819e6) / (16 * 8.1e11) + (2 * 256 * 30e9) / (16 * 1.97e14) = 21ms`。

正如你所见,这里存在一个清晰的吞吐-延迟权衡。小 batch 很快,但硬件利用不充分;大 batch 较慢,但效率更高。下面是针对一些较老 PaLM 模型计算出的延迟-吞吐帕累托前沿(来自 ESTI 论文):

<b>图:</b>若干 PaLM 模型在成本(即吞吐)与延迟之间的帕累托前沿。注意,芯片数(C)和 batch size(B)会让你沿着帕累托前沿移动;唯一的例外是绿色点(PaLM 540B,C:32 B:16),在那里可用内存限制了可支持的 batch size,导致吞吐受损。还要注意,吞吐通常会在 batch size 接近 240 后逐渐趋于平缓。int8 权重可以带来更优的延迟-吞吐帕累托最优解,但并不能带来更高的最大吞吐。
图:若干 PaLM 模型在成本(即吞吐)与延迟之间的帕累托前沿。注意,芯片数(C)和 batch size(B)会让你沿着帕累托前沿移动;唯一的例外是绿色点(PaLM 540B,C:32 B:16),在那里可用内存限制了可支持的 batch size,导致吞吐受损。还要注意,吞吐通常会在 batch size 接近 240 后逐渐趋于平缓。int8 权重可以带来更优的延迟-吞吐帕累托最优解,但并不能带来更高的最大吞吐。

我们不仅可以通过调 batch size 来权衡延迟与吞吐;如果 HBM 成为限制因素,我们还可能偏好更大的拓扑,而不是更小的拓扑,以容纳更大的 batch。下一节会更详细地讨论这一点。

**要点:** 如果你关心 generation 吞吐,就应尽可能使用最大的单芯片 batch size。任何高于 TPU 算术强度阈值($B_\text{crit}$,通常是 120 或 240)的单芯片 batch size,都会使吞吐最大化。你可能需要扩大拓扑规模才能做到这一点。更小的 batch size 则可以用吞吐换更低的延迟。

从硬件角度看,这里还有一些需要补充说明的 caveat。点击查看。 上面的分析都比较理论化。实践中,我们往往看不到那么锋利的 roofline,原因有几点: * 我们假设 HBM 读取能与 FLOPs 完美重叠,但这并不现实,因为编译器(XLA)并非无所不能。 * 对于切分后的模型,XLA 也经常无法把模型切分矩阵乘中的 ICI 通信与 FLOPs 本身高效重叠,因此在线性层上,我们往往会在 $$\text{BS}=32$$ 以上就开始承受额外延迟。 * 即使 batch size 超过理论 roofline,由于重叠并不完美,吞吐通常仍会继续略有提升,不过这个经验法则依然很有参考价值。

内存呢?

前面我们花了不少时间讨论带宽和 FLOPs,却还没认真讨论内存。由于引入了新的数据结构 KV cache,推理时的内存图景与训练时非常不同。为了说明差异,这里选一个真实模型(LLaMA 2-13B)作为例子:

hyperparam value
L (num_layers) 40
D (d_model) 5,120
F (ffw_dimension) 13,824
N (num_heads) 40
K (num_kv_heads) 40
H (qkv_dim) 128
V (num_embeddings) 32,000

推理时到底是什么在占用内存?首先当然是参数。把它们算一遍,我们有:

param formula size (in bytes)
FFW params d_model2 x ffw_multiplier x 3(对应 SwiGLU 的 gate、up 和 down 投影)x n_layers 5,120 x 5,120 x 2.7 x 3 x 40 = 8.5e9
Vocab params 2(输入和输出 embedding)x n_embeddings x d_model 2 x 32,000 x 5,120 = 0.3e9
Attention params [2(q 和 output)x d_model x n_heads x d_qkv + 2(k 和 v)x d_model x n_kv_heads x d_qkv] x n_layers (2 x 5,120 x 40 x 128 + 2 x 5,120 x 40 x 128) x 40 = 4.2e9

把这些参数加起来,我们得到 8.5e9 + 4.2e9 + 0.3e9 = 13e9 个总参数,与预期一致。正如前几节所说,在训练时,我们可能会把参数存成 bfloat16,同时还保留 float32 的优化器状态。这可能会用掉大约 100GB 内存。相比之下,梯度检查点往往会占用数 TB,远远更夸张。

那么推理有什么不同? 在推理时,我们只存一份参数,假设也是 bfloat16。这会占用 26GB;而实践中,通过量化通常还能做得更好。这里没有优化器状态,也没有梯度。由于我们不会做 checkpoint(即为了反向传播而保留激活),所以激活占用在 prefill尤其是借助 Flash Attention,它避免了注意力矩阵的显式物化 和 generate 两个阶段都可以忽略不计。比如,prefill 8k token 时,单个激活只需要大约 8,192 x 5,120 x 2 bytes = 80MB 的内存。更长的 prefill 也可以拆成许多更小的前向传播,因此长上下文并不会构成大问题。Generation 用到的 token 更少,所以激活更可以忽略。

真正的主要区别是 KV cache。它保存了所有历史 token 的 key 和 value 投影,其大小只受允许的最大序列长度限制。对 $$T$$ 个 token 来说,总大小为

$$\text{KV cache size} = 2 \cdot \text{bytes per float} \cdot H \cdot K \cdot L \cdot T$$

其中 $$H$$ 是每个 head 的维度,$$K$$ 是 KV heads 的数量,$$L$$ 是层数,而前面的 2 是因为要同时存储 key 和 value。

这个东西很快就会变得非常大,哪怕 batch size 和上下文长度都只是中等水平。对 LLaMA-13B 来说,单条 8192 长度序列在 bf16 下的 KV cache 大小为

$$8192\ (T) \times 40\ (K) \times 128\ (H) \times 40\ (L) \times 2\ (\text{bytes}) \times 2 = 6.7 \text{GB}$$

仅仅 4 份这样的 KV cache,就已经超过参数本身的内存占用了! 需要说明的是,LLaMA 2 并不是针对长上下文 KV cache 大小做过优化的模型(事情不总是这么糟,因为通常 $K$ 会小得多,比如 LLaMA-3 就是如此),但这个例子依然很能说明问题。我们绝不能在内存或延迟估算里忽略它们。

为 LLaMA 2-13B 建模吞吐与延迟

现在来看看:如果我们试图在 8xTPU v5e 上,以不同 batch size“完美高效”地做 generation,直到前面推导出的理论最大吞吐临界 batch size(240),会发生什么。

Batch Size 1 8 16 32 64 240
KV Cache Memory (GiB) 6.7 53.6 107.2 214.4 428.8 1608
Total Memory (GiB) 32.7 79.6 133.2 240.4 454.8 1634
Theoretical Step Time (ms) 4.98 12.13 20.30 36.65 69.33 249.09
Theoretical Throughput (tokens/s) 200.61 659.30 787.99 873.21 923.13 963.53

8x TPU v5e 提供 128GiB HBM、6.5TiB/s 的 HBM 带宽(每片 0.82TiB/s)和 1600TF/s 的算力。

对这个模型而言,增大 batch size 的确会提高吞吐,但收益递减得非常快。batch size 超过 16 就会 OOM,而想接近 240,则需要多出一个数量级的内存。更大的拓扑可以改善延迟,但我们已经撞上了单芯片吞吐的墙。

假设我们保持总参数量不变,但通过“魔法”把 KV cache 缩小 5 倍(比如使用 1:5 的 GMQA,也就是 40 个 Q heads 共享 8 个 KV heads,更多细节见下一节)。

Batch Size 1 8 16 32 64 240
KV Cache Memory (GiB) 1.34 10.72 21.44 42.88 85.76 321.6
Total Memory (GiB) 27.34 36.72 47.44 68.88 111.76 347.6
Theoretical Step Time (ms) 4.17 5.60 7.23 10.50 17.04 52.99
Theoretical Throughput (tokens/s) 239.94 1,429.19 2,212.48 3,047.62 3,756.62 4,529.34

KV cache 更小之后,虽然收益递减仍然存在,但单芯片理论吞吐会一直增长到 batch size 240。我们可以塞下大得多的 batch 64,而且所有 batch size 下的延迟也都更好。延迟、最大吞吐,以及最大可容纳 batch size 都有了显著改善!事实上,后续几代 LLaMA 就用了这个优化,例如 LLaMA-3 8B 具有 32 个 query heads 和 8 个 KV heads(source)。

**要点:** 除了参数之外,KV cache 的大小对模型最终的推理性能也有极大影响。我们希望通过架构设计与运行时优化的组合,把它控制在合理范围内。

提升生成吞吐与延迟的技巧

自最初的 Attention is All You Need 论文以来,人们提出了许多让模型更高效的技术,其中很多都专门瞄准 KV cache。总体来说,更小的 KV cache 会让我们更容易在不伤害延迟的前提下,增大 generation 阶段的 batch size 和上下文长度,也会减轻 Transformer 周边系统(比如请求缓存)的压力。暂且不考虑对模型质量的影响,我们可能会看到:

Grouped multi-query attention(又名 GMQA、GQA): 我们可以减少 KV head 的数量,并在注意力机制中让多个 Q heads 共享它们。极端情况下,甚至可以让所有 Q heads 共用一个 KV head。相对于纯 MHA,这会按 Q:KV 的比例缩小 KV cache,而实践中人们观察到,模型性能对这种改动通常并不敏感。

assets/img/gmqa.png

这实际上也提高了 attention 计算的算术强度(参见第 4 节中的问题 4)。

混入一些局部注意力层: 局部注意力会把上下文限制在一个较小到中等长度的窗口内。在训练和 prefill 时,这意味着把注意力矩阵从三角形 mask 成一条对角带。这样一来,本地注意力层的 KV cache 最大长度就被有效限制住了。如果模型中既有一些全局层,又混入一些局部层,那么当上下文长度超过局部窗口后,KV cache 的整体大小就会显著下降。

跨层共享 KV: 模型可以学会按某种模式在不同层之间共享相同的 KV cache。虽然这确实能减少 KV cache 大小,并在增大 batch size、缓存、离线存储等方面带来好处,但共享的 KV cache 可能需要从 HBM 中被读取多次,因此它并不一定能改善 step time。


 <b>左:</b>多层纯全局注意力。<b>右:</b>一种在全局 / 局部层间交错,并与相邻层共享的示例模式。来源:<a href="https://research.character.ai/optimizing-inference/?ref=blog.character.ai">Character.ai 博客</a>.
左:多层纯全局注意力。右:一种在全局 / 局部层间交错,并与相邻层共享的示例模式。来源:Character.ai 博客.

量化: 推理通常对参数和 KV 的精度没有那么敏感。通过把参数和 KV cache 量化(例如 int8、int4、fp8 等),我们可以同时减少两者的内存带宽开销、降低达到 compute roofline 所需的 batch size,并节省内存以支持更大的 batch size。量化还有一个额外优点:即使模型训练时没有考虑量化,也常常可以在训练后直接应用。

使用 ragged HBM reads 与 Paged Attention: 前面的计算里我们为每个 KV cache 都分配了 8k 上下文,但实际上,往往没有必要把整份 KV cache 都从内存里读出来。请求长度分布通常很广,而且很多请求并不会真正用满模型的最大上下文,因此我们常常可以实现一些 kernel(例如 Flash Attention 的变体),只读取 KV cache 中非 padding 的部分。

Paged Attention 是在此基础上的进一步改进,它用类似操作系统页表的方式存储 KV cache,基本上避免了对 KV cache 进行整块 padding。这会增加不少复杂度,但也意味着每个 batch 只占用自己真正需要的那部分内存。它是一种运行时优化,因此同样与具体架构无关。

<b>图:</b>在 generation 过程中,一个 token("forth")会访问多个 KV cache block/page。通过对 KV cache 进行分页,我们避免加载或存储超过实际需要的内存。图引自 <a href="https://arxiv.org/pdf/2309.06180">PagedAttention 论文</a>。
图:在 generation 过程中,一个 token("forth")会访问多个 KV cache block/page。通过对 KV cache 进行分页,我们避免加载或存储超过实际需要的内存。图引自 PagedAttention 论文

**整体来看:** 与标准 MHA Transformer 相比,这些 KV cache 优化合起来可以把 KV cache 大小缩小一个数量级以上。这能让 Transformer 的整体成本也获得一个数量级级别的改善。

将推理分布到多个加速器上

到目前为止,我们一直在模糊处理如何把推理扩展到单芯片之外。沿着第 5 节的思路,我们来看看有哪些可选策略,以及它们各自的权衡。和前面一样,我们会分别讨论 prefill 和 generation。

Prefill

从 roofline 角度看,prefill 几乎与训练完全一样,因此几乎同样的技术和权衡都成立,包括模型(Megatron)并行、序列切分(在上下文足够长时)、流水线,甚至 FSDP 都是可行的!你唯一要做的,就是把 KV 保留下来,以便后面做 generation。和训练一样,增加芯片数量会带来更多 FLOPs/s(从而可能降低 TTFT),但也会引入通信开销(从而可能降低单芯片吞吐)。

切分 prefill 的通用规则: 下面给出一套普遍适用的 prefill 规则。我们先假设只对单条序列做 prefill(没有 batch 维):

  1. 模型切分: 我们通常会先做一定程度的模型并行,直到达到 ICI 受限点。正如第 5 节所说,对单一轴来说,这大约是 $F / 2200$(通常约为 4 到 8 路切分)。
  2. 序列并行: 超过这个点后,就做序列并行(类似数据并行,但切在序列维上)。虽然序列并行会给 attention 带来额外通信,但在更长上下文下,这一额外成本通常相当小。与训练时类似,我们也可以把通信和计算重叠起来(Megatron 使用 collective matmuls,ring attention 也是类似思路)。

**要点:** 在 prefill 阶段,凡是训练时能正常工作的切分方式,几乎都能很好地用于推理。先把模型并行做到 ICI 受限点,再做序列并行。

Generation

Generation 比 prefill 复杂得多。首先,我们很难拿到很大的 batch size,因为必须把很多请求凑在一起。其次,延迟目标也更低。这两点叠加起来,意味着 generation 通常更偏向 memory-bound,也更容易被通信开销拖累,因此可选的切分策略受到很大限制:

  1. FSDP 不可能: 因为在把参数和 KV cache 从 HBM 读到 MXU 时,我们已经是 memory-bound 了,所以绝不希望再通过 ICI 来搬运它们,毕竟 ICI 比 HBM 慢了好几个数量级。我们想搬的是激活,而不是权重。 这意味着类似 FSDP 的方法在 generation 中通常完全不可行。训练后不小心把它留着没关,是导致性能退化一个数量级的常见坑。

  2. 没有理由做数据并行: 纯数据并行没有帮助,因为它只会复制参数,并不能让我们更快地加载参数。你还不如直接多起几个模型副本。这里的意思是:起多个小 batch size 的独立服务副本。从模型层面做数据并行是严格更差的。

  3. 没有 sequence,就没有 sequence sharding。 祝你 sequence sharding 好运吧。

因此,对稠密模型的 generation 来说,基本只剩下各种模型切分的变体。 和 prefill 一样,最简单的做法是直接做模型并行(激活完全复制,MLP 权重沿 hidden 维完全切分),通常做到 4 到 8 路,直到 ICI 受限。但由于我们往往受限于内存带宽,其实还可以继续超过这个界限,以换取更低延迟!

关于 generation 中的 ICI 界限: 训练时,我们希望自己是 compute-bound,因此 roofline 关注的是 ICI 通信何时会比 FLOPs 更慢。然而在 generation 中,如果我们受限于参数加载的内存带宽,那么即使继续增加模型切分、跨过这个点,也可能在几乎不伤害吞吐(tokens/sec/chip)的前提下改善延迟。更多模型切分意味着更多 HBM 通道可用来加载权重,而 FLOPs 反而不再关键。这里的意思是:既然 FLOPs 时间不是瓶颈,我们真正需要担心的是 ICI 时间会不会超过参数加载时间。 下面来看看,在它成为瓶颈之前,我们最多能做多少模型并行。

$$\begin{align*}T_\text{HBM comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{ICI comms} = \frac{2BD}{W_\text{ici}}\end{align*}$$ $$T_\text{ICI comms} > T_\text{HBM comms} \rightarrow \frac{W_\text{hbm}}{W_\text{ici}} > \frac{F}{Y \cdot B} \rightarrow Y > F / (B \cdot \beta)$$

其中 $\beta = W_\text{hbm} / W_\text{ici}$。对 TPU v5e 和 TPU v6e,这个数通常约为 8。也就是说,若 $F$ 为 16,384、$B$ 为 32,理论上我们可以把模型并行做到 16384 / (32 * 8) = 64 路,而不会明显损伤吞吐。不过这假设我们也能把 KV cache 完整地切成 64 份,这其实并不容易,下面会展开。

在 attention 层里,我们也会像 Megatron 那样沿 head 维对 $$W_Q$$ 和 $$W_O$$ 做模型切分。KV 权重本身很小,往往复制它们反而比把它们切到超过 $K$ 路更划算。

**要点:** generation 阶段我们真正的选择,基本只剩各种模型并行变体。我们希望移动的是激活,而不是更大的 KV cache 或参数。当 batch size 较大时,我们按 FLOPs-ICI 界限($F / \alpha$)来做模型并行;当 batch size 较小时,可以通过更激进的模型切分来改善延迟(代价是吞吐略有下降)。如果想做的模型切分路数超过了 KV heads 数量,那么也可以沿 batch 维来切分 KV。

KV cache 的切分

我们还有一个额外的数据结构也必须切分,那就是 KV cache。 同样地,我们几乎总是希望避免复制 cache,因为它是 attention 延迟的主要来源。为此,我们首先像 Megatron 那样,沿 head 维切分 KVs。这最多只能做到 $K$ 路切分,所以对于 head 数较少的模型,我们会先尽可能多地切 head 维,然后再沿 batch 维切,也就是得到 $\text{KV}[2, B_Z, S, K_Y, H]$。这意味着 KV cache 会被完全分布式存储。

<b>图:</b>注意力机制的比较:(a)纯模型切分下的多头注意力;(b)对 KV cache 进行 batch 维切分时的多查询注意力。注意这里需要额外两次 AllToAll,把激活从模型切分切换到 batch 切分,这样它们才能作用于 KV cache。
图:注意力机制的比较:(a)纯模型切分下的多头注意力;(b)对 KV cache 进行 batch 维切分时的多查询注意力。注意这里需要额外两次 AllToAll,把激活从模型切分切换到 batch 切分,这样它们才能作用于 KV cache。

这样做的代价是:每一层 attention 都要多做两次 AllToAll。一次把 Q 激活转成 batch 切分,好让我们在 batch 切分下计算 attention;另一次再把 batch 切分后的 attention 输出转回纯模型切分。

下面给出完整算法! 这里我们把同时沿 $Y$ 和 $Z$ 维做模型并行的完整 attention 算法写出来。抱歉,下面我会同时用 $K$ 表示 key 张量和 KV head 维度。令 $M=N/K$。
1. X[B, D] = ...(来自上一层的现有激活,未切分) 2. K[BZ, S, KY, H], V[BZ, S, KY, H] = ...(现有 KV cache,按 batch 切分) 3. Q[B, NYZ, H] = X[B, D] \* WQ[D, NYZ, H] 4. Q[BZ, NY, H] = **AllToAll**Z->B(Q[B, NYZ, H]) 5. Q[BZ, KY, M, H] = **Reshape**(Q[BZ, NY, H]) 6. O[BZ, S, KY, M] = Q[BZ, KY, M, H] \*H K[BZ, S, KY, H] 7. O[BZ, S, KY, M] = **Softmax**S(O[BZ, S, KY, M]) 8. O[BZ, KY, M, H] = O[BZ, S, KY, M] \*S V[BZ, S, KY, H] 9. O[B, KY, MZ, H] = **AllToAll**Z->M(O[BZ, KY, M, H]) 10. O[B, NYZ, H] = **Reshape**(O[B, KY, MZ, H]) 11. X[B, D] {UYZ} = WO[NYZ, H, D] \*N,H O[B, NYZ, H] 12. X[B, D] = **AllReduce**(X[B, D] { UYZ}) 这相当复杂,但你大体上能看出它是怎么工作的。新增通信虽然也有成本,但它们操作的是较小的激活;作为交换,我们节省了加载 KV(静止不动的大对象)所需的大量内存带宽。

设计一个高效的推理引擎

到目前为止,我们讨论的都是如何把单独的 prefill 与 generate 操作各自高效地优化和切分。要真正把它们用好,我们还需要设计一个推理引擎,让这两类操作在我们期望的延迟 / 吞吐帕累托前沿上的某个点持续喂饱。

最简单的方法,就是先跑一批 prefill,再跑一批 generation:

<b>图:</b>在最简单的配置下,请求被聚合起来,服务器先交替运行一批 prefills,然后反复调用 generate 函数,直到所有序列都完成。
图:在最简单的配置下,请求被聚合起来,服务器先交替运行一批 prefills,然后反复调用 generate 函数,直到所有序列都完成。

这很容易实现,也是大多数代码库中的第一版推理方案,但它有多个缺点:

  1. 延迟很糟糕。 我们把 prefill 和 generate 的 batch size 绑在了一起。prefill batch 很大时,首 token 时间(TTFT)会非常差,因为在任何用户看到 token 之前,你必须先把所有 prefill 都做完;而在小 batch 下,generate 吞吐又会很差。
  2. 短生成会被长生成阻塞。 许多序列会比其他序列更早结束,于是在 generation 过程中留下空的 batch 槽位,进一步伤害 generate 吞吐。batch 越大、生成越长,这个问题越严重。
  3. Prefill 需要 padding。 所有 prefill 都会被 pad 到最长序列,我们因此浪费了大量算力。虽然有解决方案,但历史上 XLA 让跳过这些 FLOPs 变得很困难。batch 越大、prefill 序列越长,这个问题越明显。
  4. 我们被迫让 prefill 和 generation 共享同一种切分。 两者都跑在同一 slice 上,这意味着二者必须使用相同的拓扑与切分(除非你保留两份权重),而这通常不利于性能。例如,generate 往往希望有更强的模型切分。

因此,这种方法只推荐用于边缘场景(通常只需要服务单个用户,而且硬件 FLOPs/byte 更低),或者 Transformer 代码库生命周期早期为了快速迭代而使用(因为它足够简单)。

稍微更好的办法是:prefill 保持 batch size 1(这时它依然是 compute-bound,但延迟合理),而在 generation 阶段把多个请求批在一起:

assets/img/interleaving.png

这样可以避免批量 prefill 带来的 TTFT 浪费,同时保持较高的 generation 吞吐。我们把这种配置称为 interleaved,因为它把 prefill 和 generation 步骤“交错”在一起。这对评测之类以吞吐为首要目标的大规模生成场景非常有力。调度器可以配置成:只要有 generation 槽位释放出来,就优先安排 prefill,从而即便在很大的 generation batch 下,也能保持高利用率。由于 prefill 不再和其他请求一起 batch,我们也能避免把它 pad 到最大长度。

它的主要缺点是:当服务器在执行某个 prefill 时,其他所有请求的 generation 都会暂停,因为 prefill 会吃掉所有计算资源。也就是说,正在解码响应的用户 A,会被正在做 prefill 的用户 B 阻塞。这意味着即便 TTFT 改善了,token 生成仍然会抖动,而且平均速度偏慢;对很多应用来说,这不是好的用户体验,因为别人的 prefill 会出现在你请求整体延迟的关键路径上。

为了解决这个问题,我们把 decode 和 prefill 分开。虽然 Transformer 推理可以放在一台服务器上完成,但从延迟角度看,更好的做法往往是让这两项不同任务分别跑在两组 TPU/GPU 上。Prefill 服务器生成 KV cache,然后通过网络发送给 generate 服务器;后者把多份 cache 批在一起,并为每个请求生成 token。我们把这称为 “解耦式(disaggregated)” 服务。

assets/img/disaggregation.png

这种做法有几个优点:

  1. 可扩展的低延迟: 一个用户的请求不会被另一个用户阻塞,除非 prefill 容量本身不足。请求应当立即被 prefill,然后发送到 generation 服务器,再立刻插入 generation buffer。如果我们预计会有很多并发请求进入,就可以独立扩展 prefill 服务器数量,而不需要同步扩展 generate 服务器数量,从而避免用户在 prefill 队列里等待太久。

  2. 专门化: 很多时候,prefill 和 generate 在延迟最优参数切分策略 / 硬件拓扑上差异很大(例如,generate 常常更需要模型并行,而 prefill 不需要)。强行让两者使用相同切分会同时伤害两者的性能,而保留两套权重又浪费内存。此外,把 prefill 挪到独立服务器后,它不再需要持有除当前正在处理那条之外的任何 KV cache。这样一来,我们就能腾出更多内存用于历史缓存(见下一节)或进一步优化 prefill 延迟。

一个缺点是:现在 KV cache 需要通过网络搬运。这通常是可接受的,但它再次说明了为什么我们希望 KV cache 尽可能小。

**要点:** 对于既关注延迟又追求高吞吐的服务场景,我们通常必须把 prefill 和 generation 拆到不同服务器上:prefill 用 batch 1 运行,而 generation 则把许多并发请求批在一起。

连续批处理

上面问题(2)引出了 continuous batching(连续批处理) 的概念。我们会优化并编译:

然后,我们用一个 orchestrator 把这两个函数组合起来。它负责排队进入的请求,根据当前可用的 generation 槽位决定调用 prefill 还是 generate,管理历史缓存(见下一节),并把生成出的 token 流式返回出去。

assets/img/continuous-batching.gif

前缀缓存

既然 prefill 既昂贵又是 compute-bound(这意味着我们可腾挪空间更小),那么降低它成本的最佳方法之一,就是少做一些 prefill。由于 LLM 是自回归模型,查询 ["I", "like", "dogs"] 和 ["I", "like", "cats"] 在前两个 token 上会产生完全相同的 KV cache。也就是说,从原则上讲,如果我们先算出 “I like dogs” 的 cache,再算 “I like cats” 的 cache,那么第二次只需要做原来 1 / 3 的计算。通过复用 cache,我们可以省掉大部分工作。这在几个特定场景中特别有用:

  1. 聊天机器人: 大多数聊天对话都是严格地在前文基础上追加内容。这意味着,如果我们能保存每一轮对话的 KV cache,那么后续轮次就只需要为新追加的 token 做计算。
  2. 少样本提示: 任何 few-shot prompt 都可以被保存下来并免费复用。系统提示词通常也符合这种模式。

它之所以难做,唯一的根本原因在于内存限制。正如我们已经看到的,KV cache 很大(往往是数 GB),而要让缓存真正有用,我们必须把它们保留到后续请求到来为止。通常,prefill 服务器上任何未被使用的 HBM 都可以拿来做本地缓存系统。此外,加速器的 CPU host 通常也有很多内存(例如一台 8xTPUv5e 服务器有 128GiB 的 HBM,但大约有 450GiB 的 Host DRAM)。这部分内存比 HBM 慢得多,通常慢到不足以支撑 generation step,但对于缓存读取来说已经够快。实践中:

<b>图:</b>采用 LRU trie 实现的 KV 前缀缓存。通过共享前缀,我们可以避免重复占用 KV 内存。来源:<a href="https://research.character.ai/optimizing-inference/?ref=blog.character.ai">Character.ai 博客</a>。
图:采用 LRU trie 实现的 KV 前缀缓存。通过共享前缀,我们可以避免重复占用 KV 内存。来源:Character.ai 博客

来看一个实现:JetStream

Google 开源了一个实现这套逻辑的库,叫做 JetStream。这个服务器由一组 “prefill engines” 和 “generate engines” 组成,通常分布在不同的 TPU slice 上,并由单个控制器统一编排。Prefill 发生在“prefill thread”中,而 generation 发生在“generate thread”中。它还包含一个“transfer thread”,负责把 KV cache 从 prefill slice 复制到 generate slice。

Engine 接口(实现见这里)是一个任意 LLM 都必须提供的通用接口。关键方法包括:

我们也提供了 JetStream 的 PyTorch 版本,见这里

练习题

这一节中,我会虚构一个基于 LLaMA-2 13B 的新模型。参数如下:

hyperparam value
L (num_layers) 64
D (d_model) 4,096
F (ffw_dimension) 16,384
N (num_heads) 32
K (num_kv_heads) 8
H (qkv_dim) 256
V (num_embeddings) 32,128

问题 1: 上述模型一共有多少参数?在 int8 下,它的每 token KV cache 有多大?你可以假设输入和输出投影矩阵是共享的。

点击这里查看答案。 **参数量:** * MLP 参数量:$L * D * F * 3$ * Attention 参数量:$L * 2 * D * H * (N + K)$ * 词表参数:$D * V$(因为这两个矩阵是共享的) 因此,总参数量为 $L * D * (3F + 2H * (N + K)) + D * V$。代入上面的数字,得到 `64 * 4096 * (3*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 18.4e9`。所以,这个模型大约有 184 亿参数。 在 int8 下,每个 token 的 KV cache 大小为 $2 * L * K * H$,也就是 `2 * 64 * 8 * 256 = 262kB`。

问题 2: 假设我们想在 TPUv5e 4x4 slice 上部署这个模型,并且能够在该拓扑上完全切分 KV cache。如果全部都使用 int8,并且要支持 128k 长序列,最大能容纳多大的 batch size?如果把 KV heads 数量降到 1,又会怎样?

点击这里查看答案。 在 int8 下,KV cache 每个 token 的大小为 $2 \cdot L \cdot K \cdot H$,即 `2 * 64 * 8 * 256 = 262kB`。对于 128k 序列,这意味着每个 batch 条目需要 `262e3 * 128e3 = 33.5GB`。每个 TPU 有 16GB HBM,还得包含参数,因此能容纳的最大 batch size 为 `(16 * 16e9 - 18.4e9) / 33.5e9 = 7`。如果 $K=1$,这个数会再乘以 8,也就是大约 56。

问题 3: 假设参数是 int8,并且它们在 TPU v5e 4x4 slice 上被完全切分。从 HBM 把所有参数载入 MXU 需要多久?这可以作为逐 step 延迟的一个很好下界。

点击这里查看答案。 我们共有 18.4B 参数,即 int8 下的 18.4e9 字节。每个芯片的 HBM 带宽为 8.1e11,因此如果能完全跑满 HBM 带宽,所需时间约为 `18e9 / (8.1e11 * 16) = 1.3ms`。

问题 4: 假设我们想在 TPUv5e 4x4 slice 上,以 int8 FLOPs 以及 int8 参数 / 激活来部署这个模型。对于 prefill 和 decode,应该如何切分它?提示:也许可以先回答下面几个问题:

  1. 4x4 上的 ICI 拓扑是什么样的?
  2. tensor parallelism 的 roofline 界限是多少?
  3. KV cache 可以如何切分?

在这种切分方案下,generation 的大致逐 step 延迟是多少?

问题 5: 假设上面的模型其实是一个 MoE。MoE 模型本质上相当于有 E 份 FFW block 副本的稠密模型。每个 token 会通过其中 k 个 FFW block,并将这 k 个结果平均后作为输出。这里设定 E=16k=2,其余配置保持不变。

  1. 它一共有多少总参数、多少激活参数?激活参数指任意一个给定 token 实际会用到的参数。
  2. 在 TPU v5e 上,需要多大的 batch size 才会变成 FLOPs-bound?
  3. 它的每 token KV cache 有多大?
  4. 当输入有 T 个 token 时,一次前向传播包含多少 FLOPs?
点击这里查看答案。 (1) 作为 MoE,每个 MLP block 现在有 $3 * E * D * F$ 个参数,比稠密版本多了 $E$ 倍。因此总参数量变成 $L * D * (3EF + 2H * (N + K)) + D * V$,也就是 `64 * 4096 * (3*16*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 212e9`,大约增加了 12 倍。至于激活参数,用的是 $k$ 而不是 $E$,因此总量为 `64 * 4096 * (3*2*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 31.2e9`,相比稠密版本增加不到 2 倍。 (2) 因为参数数量增加了 $E$ 倍,但 FLOPs 只增加了 $k$ 倍,所以 HBM roofline 会放大 $E/k$ 倍。这意味着在 TPU v5e 上,我们大约需要 `240 * (16 / 2) = 1920` 个 token。 (3) KV cache 大小保持不变,因为 MoE 属性不会改变注意力机制本身。 (4) 仍然是 $2 \cdot \text{activated params} \cdot T$。因此就是 $2 * \text{31.2e9} * T$。

问题 6: 对于 MoE,我们可以做“expert sharding”,也就是沿网格的一个轴把 expert 切开。在我们的标准记号里,第一个 FFW 权重形状是 [E, D, F],我们把它切成 [EZ, DX, FY],其中 X 只在训练时作为 FSDP 维使用。现在假设我们想在 TPU v5e 上做推理:

  1. 对于上面的模型,在 TPU v5e 8x16 slice 上取 Y=8、Z=16 时,HBM 权重加载时间是多少?每个 TPU 还剩多少空闲 HBM?
  2. 能容纳这个模型的最小 slice 是什么?

问题 7 [2D 模型切分]: 这一题我们来推导 ESTI 论文中所谓的 2D weight-stationary sharding。我们会在附录 B 简单介绍它,但你可以先自己做这道题,看看能否先把数学推出来。2D weight stationary 的基本思想是:沿 $D$ 和 $F$ 两个轴同时切分权重,让每个分块尽量接近正方形。这样可以降低通信负担,并使模型能够稍微扩展得更远一些。

下面是 2D weight stationary 的算法:

1. In[B, DX] = **AllGather**YZ(In[B, DXYZ]) 2. Tmp[B, FYZ] {UX} = In[B, DX] \*D Win[DX, FYZ] 3. Tmp[B, FYZ] = **AllReduce**X(Tmp[B, FYZ] {UX}) 4. Out[B, DX] {UYZ} = Tmp[B, FYZ] \*F Wout[FYZ, DX] 5. Out[B, DXYZ] = **ReduceScatter**YZ(Out[B, DX] {UYZ})

你的目标是推导出该算法的 $T_\text{math}$ 和 $T_\text{comms}$,并找出它何时会优于传统的 3D 模型切分。

点击这里查看答案! 我们来推导 $T_\text{math}$ 和 $T_\text{comms}$。所有 FLOPs 都被完全切分了,所以和之前一样有 $T_\text{math} = 4BDF / (N \cdot C)$,而通信时间现在为 $$\begin{align*} T_\text{2D comms} = \frac{2BD}{2X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} + \frac{2BD}{2X \cdot W_\text{ici}} = \frac{2BD}{X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} \end{align*}$$ 这里注意,AllReduce 的代价是两倍,同时我们会按每个操作实际跨越的轴数来缩放通信量。假设我们可以自由选择拓扑,并且设 $F=4D$(如同 LLaMA-2),则可以证明(用一些基础微积分)最优的 $X$、$Y$、$Z$ 取值为 $X = \sqrt{N / 8}$、$YZ = \sqrt{8N}$,从而总通信量为 $$T_\text{2D comms} = \frac{2B}{W_\text{ici}} \left(\frac{D}{X} + \frac{8D}{YZ}\right) = \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \approx \frac{11.3 BD}{\sqrt{N} \cdot W_\text{ici}}$$ 首先,回顾上文,普通 1D 模型并行的通信时间为 $T_\text{model parallel comms} = 4BD / (3 \cdot W_\text{ici})$。那么新的通信什么时候会更小? $$\begin{align*} T_\text{model parallel comms} > T_\text{2D comms} \iff \frac{4BD}{3 \cdot W_\text{ici}} > \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \\ \iff N > 128 \cdot \left(\frac{3}{4}\right)^2 = 81 \end{align*}$$ 对于一般的 $F$,这个条件可写为 $$N > 32 \cdot \left(\frac{F}{D}\right) \cdot \left(\frac{3}{4}\right)^2$$ 这说明:如果芯片数超过 81,我们就更适合用这个新方案。这个结果稍微有点反直觉,因为从经验上看,历史上我们在大约 20 路 tensor parallelism 时就开始 ICI-bound 了。但这里,即便我们已经 communication-bound,总通信量仍然会随着总芯片数增加而继续下降!这说明我们还可以继续增加芯片、增加 batch size、做更大的参数扩展,并同时获得更低延迟。

第 7 部分到这里就结束了!如果你想继续阅读第 8 部分,看看我们如何在 TPU 上服务化部署 LLaMA 3,请点[这里](../applied-inference)。

附录

附录 A:batch size > 240 这条规则到底有多真实?

上面给出的简化规则是:batch size 必须大于 240 token,模型才会是 compute-bound。这个说法大体正确,但它忽略了 TPU 的一种能力:当其他操作没有用满 HBM 时(例如进行跨设备通信时),TPU 其实可以预取权重。

下面是一张经验图,展示了一个小型 Transformer 的 layer time(单位微秒)。这个模型的 dmodel 为 8192、dff 为 32768,每层只有 2 个 matmul。图来自这个 Colab 笔记本。你会看到,step time 在 batch 240 左右之前增长非常缓慢,而在那之后才开始线性增长。

assets/img/batch-scaling-latency.png

下面则是真实的吞吐(单位 tokens / us)。这一点说明得更清楚。由于这里的每层大约是 6 亿参数,并且按 4 路切分,我们预计最低延迟大约为 365us。

assets/img/batch-scaling-throughput.png

因此,至少在这个模型里,我们确实看到吞吐会随着每个 data parallel shard 的 batch size 增长,直到大约 BS240 才趋于饱和。

附录 B:2D Weight Stationary 切分

随着拓扑规模扩大,如果我们能用到更高维的 mesh(比如 TPU 的 mesh),就可以通过引入第二个切分轴,把这个方案进一步细化为“2D Weight Sharding”。我们把它称为“2D Weight Stationary”,并在 Efficiently Scaling Transformer Inference 论文中有更详细的描述。

因为在 Megatron 中,我们只沿隐藏维 $$F$$ 对权重进行切分,所以在 1D 切分、芯片数较大时,$$F$$ 可能会明显小于 $$E$$(即 $$d_\text{model}$$ 维)。这意味着,在更大的 batch size 下,把一部分 collective 放到 MLP 第一层之后、沿 hidden 维进行,反而可能更划算。

assets/img/2d-weight-stationary.png

这张图展示了:

  1. 1D weight-stationary sharding,也就是纯 Megatron 切分。此时激活在 AllGather 后会被完全复制,而权重则完全沿隐藏维 F 切分。
  2. 2D weight stationary sharding,此时权重同时沿隐藏维 F 和归约维 E 切分,而激活则沿 E 维切分。我们会在第一层之前沿 (yz) 轴做一次 AllGather,然后沿 (x) 轴做一次 ReduceScatter。

对于 attention 层,在芯片数量较少时,Megatron 风格切分也相对简单。但 Megatron 是沿 $$n_\text{heads}$$ 维切分的,这限制了最多能切多少路。如果把 2D 切分改造到 attention 上(也就是改为切 $$n_\text{heads}$$ 维,而不是隐藏维),我们就能继续往更大规模扩展。

附录 C:受延迟限制的通信

回顾一下,在第 3 节里,我们推导过:在一条 1D ring 上,链路为全双工带宽 WICI、固定延迟为 Tmin,跨 X 个芯片对每个 TPU 上大小为 B 的张量执行一次 AllGather,需要的时间是

$$T_{total} = \max\left(\frac{T_{min} \cdot |X|}{2}, \frac{B}{W_{ICI}}\right)$$

当 B 很大时,墙钟时间会相对稳定,因为随着系统中芯片数增加,一方面你需要移动的数据量在扩大,另一方面总可用带宽也在同步扩大。

assets/img/all-gather.gif

在以低延迟为目标的推理中,由于被移动的数据量通常较小,激活上的 collective 经常会被延迟项主导(尤其是 batch size 很小时)。要直观理解这个延迟,一个简单办法是数一数:整个 collective 完成前,到底需要经过多少跳(hop)。

在 TPU 上,如果通信中与张量大小相关的那一部分小于每跳 1 微秒(所谓一跳,就是两个相邻设备之间的一次通信),那么系统瓶颈就可能变成真正发起 collective 的固定开销。对于单向 4.5e10 的 ICI 带宽,当满足 $$(\text{bytes} / n_\text{shards}) / 4.5e10 < 1e-6$$ 时,ICI 通信就会进入延迟受限区间。对于 8 路 Megatron 切分,这相当于 buffer_size < 360kB而在推理中,这个数其实并不算小: 例如在 BS=16D=8192 且使用 int8 时,激活大小为 16*8192=131kB,也就是说我们已经是 latency-bound 了。

**要点:** 当 $$\text{total bytes} < W_{ICI} \times 1e-6$$ 时,通信就会变成延迟受限。例如,在沿 $$Y$$ 维做模型并行时,int8 下当 $$Y > BD / 45,000$$ 时,我们就会进入这一 regime。

这和计算 roofline 有一个有趣的类比:我们都在为一些很小的操作支付固定成本(通信中的延迟、matmul 中的内存带宽)。

附录 D:投机采样

当我们真的非常在意端到端延迟时,还可以使用一个额外技巧,叫做投机采样(speculative sampling)。回顾一下,通常我们会让一个大型 Transformer 逐个生成 token:

assets/img/spec-sampling1.png

而在投机采样中,我们先用一个更小、更便宜的模型来生成 token,再用大模型检查结果。对于贪心解码(greedy decoding),它最容易理解:

assets/img/spec-sampling2.png
  1. 我们先从某个更小、更便宜的模型中进行贪心采样。理想情况下,这个小模型是专门训练来拟合大模型的,例如通过蒸馏;但它也可以简单到只是基于 n-gram 或一个小文本语料做 token 匹配。
  2. 当我们生成了 K 个 token 后,用大模型去计算这一路已生成 token 的所有 next-token logits。
  3. 由于我们做的是贪心解码,只需检查小模型生成的 token 是否在所有候选 token 中概率最高即可。如果某个 token 错了,我们就取最长的正确前缀,把第一个错误 token 替换成正确 token,然后回到步骤 (1)。如果所有 token 都是对的,我们就可以利用最后一个正确的 logit 再额外采样一个 token,然后再回到步骤 (1)。

为什么这能降低延迟? 这种方案本质上仍然要求我们对大模型执行“每个 token 一次前向传播”的等价 FLOPs,但因为我们可以把许多 token 一起 batch 起来,所以能在一次前向传播里完成这些 FLOPs,并利用一个事实:我们其实不是 compute-bound,因此可以“免费”多给一些 token 打分。

每个被接受的 token,从平均 FLOPs 成本上看会变得更贵(因为总会有一些 token 被拒绝,而且我们还要调用一个 draft model),但我们从硬件里榨出了更多 FLOPs,而小模型又很便宜,所以总体上仍然划算。我们还可以在多个步骤之间共享 KV cache 的读取,因此在长上下文下,投机解码也可能提升吞吐。 由于所有结果都经过了大模型核验,采样分布本身完全不会被改变(尽管在非贪心采样下,具体轨迹会有所不同)。

传统上,投机解码依赖一个比目标模型更小、但采样分布相似的模型,例如用 LLaMA-2 2B 给 LLaMA-2 70B 做 drafter,而这种配套小模型并不总是存在。即便存在,如果接受率较低,这个小 drafter 仍可能太昂贵。另一种更实用的方法,是把 drafter 嵌入主模型内部,例如在基座模型的某一较后层上增加一个专门的 drafter head。由于这个 head 与主模型共享了大部分参数,因此运行更快,而且也更容易匹配主模型的采样分布。

对于普通自回归采样,token/s 与 step time 是同一回事。我们仍然受制于本章“算术强度”一节中推导出的理论最小 step time(事实上,投机采样的 step time 往往比普通自回归采样还要慢不少;但因为它平均每一步能吐出不止 1 个 token,所以 tokens/s 仍可能高得多)。

<b>图:</b>这张图展示了 Chinchilla(DeepMind 的一个 70B 模型)搭配 4B 参数 drafter(小模型)时,每一步的延迟与投机成功率。对 XSum(自然语言数据集)来说,理想的投机提前量约为 3 到 4 个 token;而 HumanEval(代码数据集)更可预测,因此更激进的投机会带来更大收益。
图:这张图展示了 Chinchilla(DeepMind 的一个 70B 模型)搭配 4B 参数 drafter(小模型)时,每一步的延迟与投机成功率。对 XSum(自然语言数据集)来说,理想的投机提前量约为 3 到 4 个 token;而 HumanEval(代码数据集)更可预测,因此更激进的投机会带来更大收益。

那非贪心解码怎么办? 这会更复杂一些,但本质上可归结为一种受 Metropolis-Hastings 启发的算法:我们根据 logits 得到 $$P_{\text{draft model}}(\text{chosen token})$$ 和 $$P_{\text{target model}}(\text{chosen token})$$,如果这两个概率之比小于某个阈值,就以概率方式拒绝该 token。

两篇论文几乎同时推导出了这一方法,并给出了很好的实践示例。

**要点:** 投机采样是另一个强有力的杠杆,可以用吞吐去换取更低的逐 token 延迟。不过在 batch size 受限(例如硬件规模较小或 KV cache 很大)的场景下,它甚至会变成双赢。

第 8 章

在 TPU 上服务 LLaMA 3-70B

本节将讨论服务 LLaMA-3 以及高效完成这件事所需的条件。和前一个 “applied” 章节一样,试着先拿起纸笔自己算出答案,再往下看!

LLaMA 的服务故事是什么?

先回顾一下 LLaMA 3-70B 的样子(可参考第 6 节):

超参数 数值
$$n_\text{layers}$$ (L) 80
$$d_\text{model}$$ (D) 8,192
$$d_{ff}$$ (F) 28,672
$$n_\text{heads}$$ (N) 64
$$n_\text{kv heads}$$ (K) 8
$$d_\text{qkv}$$ (H) 128
$$n_\text{embeddings}$$ (V) 128,256

先从一个简单问题开始:我们应该把服务部署在什么硬件上? 基本答案是,看哪种硬件的 FLOPs / 美元最便宜。这并不总是成立;有时更大的 HBM 或 ICI 带宽会比 FLOPs 更关键,但这是一个不错的启发式规则。 因此,我们通常会选择 TPU v5e,也就是我们当前专用的推理芯片(成本来自截至 2025 年 2 月的 Google Cloud pricing):

TPU 类型 bfloat16 FLOPs/s Google Cloud 美元 / 小时 FLOPs / $
H100 9.9e14 $10.8 3.3e17
v5p 4.59e14 $4.2 3.9e17
v5e 1.97e14 $1.2 5.8e17

每个 TPU v5e 有 16GB HBM,这意味着我们必须相当激进地对模型进行分片。先来思考几个可能重要的基本量:

问题: LLaMA 3-70B 的每 token KV 缓存有多大?你可以假设我们以 int8 存储它们。这决定了在给定拓扑上我们能用多大的 batch size。

想清楚之后点这里! LLaMA 3-70B 有 8 个 KV heads,因此每 token 的大小为 `2 * K * H * L = 2 * 8 * 128 * 80 = 160kB`。 **注意这个数字有多大!** 如果序列长度是 32k token(这很常见),那么会占用 `160e3 * 32,768 = 5.3GB / 序列`。对于 BS=240,这就是 1.3TB!由于 TPU v5e 每片只有 16GB,我们至少需要大约 `(70e9 + 1.3e12) / 16e9 = 86` 片 TPU v5e,才能仅仅把这些内存装下。还要注意,这和 70GB 的模型参数相比已经非常大了。

问题: 假设我们想以 batch size 32、序列长度 8192 来服务 L3 70B,并且参数和 KV 都用 int8。总共会使用多少内存?最小可以部署在哪个 slice 上?

答案 由于我们的 KV 以 int8 存储时大小是 `160e3` 字节,因此 KV 总内存为 `160e3 * 8192 * 32 = 41.9e9` 字节。参数是 `70e9` 字节,因为每个参数占 1 字节。所以总内存使用量是 `41.9e9 + 70e9 = 112GB`。 我们能使用的最小 slice 需要有 `112e9 / 16e9 = 7` 片 TPU,或者(向上取整到一个合理的偶数拓扑)TPU v5e `4x2`。这个配置会很紧张,考虑到其他额外开销,可能还是装不下,因此最少可能需要 `4x4`(或者降低 batch size)。

问题: 在这个 batch size 和量化配置下,如果部署在 TPU v5e 4x2,每个 decode step 的延迟大约是多少?吞吐量(tokens / sec / chip)是多少?如果换成 4x4 呢?假设我们以 bfloat16 执行 FLOPs,并且所有内容都完全分片。

答案 我们可以套用上一节中的公式: $$\begin{align*} \tiny \text{理论步时延(一般形式)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\tiny \text{总内存带宽}}}_{\text{Attention(始终受带宽限制)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{总 FLOPs/s}}, \frac{\text{Parameter Size}}{\text{总内存带宽}}\right)}_{\tiny \text{MLP(可能受计算限制)}} \end{align*}$$ 这里我们的临界 batch size 大约是 120,因为参数是 int8,而 FLOPs 用的是 bfloat16。我们也可以手工计算右侧 `max` 里的两项,但那基本就是前面已经做过多次的计算。**所以我们已经深处于 matmul 和 FLOPs 都受内存带宽限制的区域。** 如果只看内存带宽,那么步时延基本就是 `(KV size + param size) / (8 * HBM bandwidth) = 112e9 / (8 * 8.1e11) = 17ms`。**所以理论上步时延大约是 17ms。** 吞吐量则是 `32 / .017 = 1882 tokens / sec`,或者 `1882 / 8 = 235 tokens / sec / chip`。 这里有一个额外注意点,就是需要检查我们的 matmul 会不会受 ICI 限制。这里我们可以给它分配 2 个轴,因此理论上只有当 $Y > 2 * F / 2200 = 2 * 28672 / 2200 = 26$ 时才会受 ICI 限制,所以完全没问题! 如果我们运行在 `4x4` 上,ICI 方面依然没问题,因此延迟会降到 `17 / 2 = 8.5ms`,但单芯片吞吐量保持不变。

思考吞吐量

我们花一点时间只考虑吞吐量。当我们优化吞吐量时,我们希望系统是计算受限的,也就是说尽量接近完全利用 TPU 的 MXU 算力。通常这意味着我们希望 batch size 尽可能大,这样每一步都能做尽可能多的工作。

问题: 在 TPU v5e 上,如果使用 bfloat16 权重和激活,我们的 batch size 需要多大才能让 matmul 进入计算受限?如果使用 int8 权重但 FLOPs 仍用 bfloat16 呢?如果是 int8 权重配 int8 FLOPs 呢?

答案 正如第 7 节所讨论的,对于任意一个满足 $B \ll D, F$ 的 bfloat16 matmul,我们有 $$\begin{equation*} T_\text{math} > T_\text{comms} \leftrightarrow \frac{2BDF}{2DF} \geq \frac{\text{TPU bfloat16 FLOPs/s}}{\text{HBM bandwidth}} = 240 \end{equation*}$$ 当我们的权重是 int8 时,分母会少一个 2,因此有 $2BDF / DF = 2B > 240$,也就是 $B > 120$,相当于把先前的临界 batch size 砍半。这对我们非常有帮助!如果我们使用 int8 权重和 int8 FLOPs,那么 TPU FLOPs/s 要用 int8 的数值,它会从 bfloat16 的 1.97e14 变成 3.94e14,几乎翻倍。这样一来我们又回到了起点,也就是大约需要 $B > 240$。 int8 权重配 bfloat16 FLOPs 是一种很常见的情况,因为无损地量化参数通常比做低精度算术更容易。

问题: 如果使用 bfloat16、int8、int4(KV 和参数都采用相同精度),并且上下文长度是 8k,那么最小可以在哪个 TPU v5e 拓扑上服务 LLaMA 3-70B?这题里你可以把 KV cache 当成可以忽略不计。

答案 这很简单!如果我们接受一个很小的 batch size,那么唯一限制就是参数内存能否装进 HBM,也就是 `ceil(num_params * sizeof(dtype) / HBM per TPU)`,再向上取整到一个合理拓扑(2 的某个倍数): | dtype | param size | KV size / token (bytes) | min TPU v5es | actual min slice | remaining HBM for KV caches | num KV caches @ 8k | | :---: | :--------: | :---------------------: | :----------: | :--------------: | :-------------------------: | :----------------: | | bf16 | 140GB | 324kB | 8.75 | 4x4 = 16 chips | 116 | 43 | | int8 | 70GB | 162kB | 4.38 | 4x2 = 8 chips | 58 | 43 | | int4 | 35GB | 81kB | 2.81 | 2x2 = 4 chips | 29 | 43 | 这非常酷!它告诉我们,如果愿意的话,我们确实可以把 LLaMA 70B 塞进 TPU v5e 2x2。只是你会注意到,能容纳的 KV cache 数量非常少。那其实就是 batch size!这意味着我们的 FLOPs 利用率会很糟。为了把 batch size 推高到 240,我们会非常乐意使用更大的拓扑。

问题: 假设我们总是使用这些拓扑所能容纳的最大 batch size,那么每个 generate step 的延迟大约是多少?

答案 这也很简单,因为我们选的 batch size 正好把所有 HBM 都填满了!这其实就是一个问题:把一整片 TPU v5e 这么多字节从 HBM 读入 MXU 需要多久?这就是 `v5e HBM / v5e HBM memory bandwidth = 16GB / 8.2e11 = 19ms`,所以大约是 **19ms / step**。如果我们假设 generation 的中位长度是 512 token,那么一次 decode 大约要 9 秒。注意,如果使用更小的 batch size,我们也许能获得稍微更好的延迟;例如如果只考虑 int4 的模型参数,那么最小时延大约是 10ms / step,因为 HBM 不再被填满。

**要点**:我们总可以通过问“把模型所有参数从 HBM 读到 MXU 需要多久”来给 decode 延迟做一个下界。当 KV cache 很小时,你可以把每一层想象成只是把对应的权重分块从 HBM 读出来,用完就丢掉。除非我们使用了很大的 batch size,或者存在大量跨设备通信,否则这通常是一个合理的下界(在 1.5 倍以内)。当 batch size 更大时,我们还需要建模 KV cache 的加载,因为那时它会压过参数加载。

同样地,在 FLOPs 受限的区域(例如训练或大 batch 推理),我们可以使用 $$\text{Total FLOPs} / (N \cdot C) = 2 \cdot \text{param count} \cdot B / (N \cdot C)$$ 这个下界,它假设没有通信。

问题: 对于上面每一种配置,这会给出怎样的单芯片吞吐量(queries / chip)?你可以假设 decode 长度的中位数是 512 token。

答案 这是一个重要问题,因为它和 cost / token 完全相关。 在我们关于 decode 长度中位数的假设下,吞吐量就是 $$B / (\text{per-step latency} \cdot \text{median steps} \cdot N) \approx 43 / (0.019 * 512 * N)$$。这大约等于 $$(4.42 / N)$$ QPS,因此代入 $$N$$ 可得: | dtype | QPS / chip | | :------: | :--------: | | bfloat16 | 0.27 | | int8 | 0.55 | | int4 | 1.11 | 请注意,这个估计相当乐观,因为它完全忽略了前向传播的工作内存(为激活和注意力分配的内存)。在使用 Flash Attention 时,这种假设不算离谱,但也并不现实。真实数字大概率只有这里的一半左右。要获得绝对最大吞吐量,我们很可能需要把芯片数量增加到两倍以上,同时显著增大 batch size。

问题: 如果我们把上述每个示例的拓扑都扩大一倍,峰值吞吐量会如何变化?

答案 如果我们在 bfloat16 下使用 4x8 slice,那么会剩余 372GB 可用于 KV cache,这允许我们把 batch size 提高到 140。由于 step time 保持不变,因此吞吐量会变成 `14.39 / num_chips`,也就是: | dtype | QPS / chip | | :---------------: | :--------: | | bfloat16 (on 4x8) | 0.44 | | int8 (on 4x4) | 0.90 | | int4 (on 2x4) | 1.80 | 如果进一步扩大,还会带来更大的收益!最大的结论是:**最小拓扑并不总是性能最好的拓扑**,特别是在 KV cache 大小成为瓶颈时。

问题: 现在让我们深入讨论分片问题。假设我们想在 TPU v5e 4x8 上以 bfloat16 来服务模型。那么在 generation 期间,我们会采用什么分片方式?能否避免进入通信受限?

答案 正如上一节讨论的,在 generation 期间,分片实际上只有一个选择:模型并行。那在进入通信受限之前,我们最多能做多少模型并行?如上一节所说,当 $$Y > \frac{F \cdot M_Y}{2200}$$ 时,模型大致会进入通信受限。对于 LLaMA 3-70B,`F = 28,672`,因此如果我们沿 2 个轴做模型分片,就会得到大约 $$Y = 28672 \cdot 2 / 2200 = 26$$,也就是说一般我们最多可以扩展到大约 16 片芯片而不进入通信受限,这允许我们使用 `4x4`,但不允许使用 `4x8`。一般来说,由于我们的计算与通信并不能完美重叠,即使这个估算也偏乐观。 **要点:我们实际上不能在 4x8 上通过纯模型并行来服务。** 在这里我们最理想的配置也只是 4x2,或者 _也许_ 4x4。 不过,正如我们讨论过的,当 batch size 很小时,我们往往可以使用更多模型并行,而不会显著损伤吞吐量,因为此时模型受的是内存带宽限制而不是 FLOPs 限制。我们之前说过,这个值大致是 $Y=F / (8\cdot B)$,因此如果 batch size 是 64,理论上我们可以扩展到 `Y = 28,672 / (8 * 64) = 56` 路模型并行之后才会受 ICI 限制。为了做一个 sanity check,我们可以看看单个 matmul 的 $T_\text{ici comms}$、$T_\text{hbm comms}$ 和 $T_\text{math}$。显然有: $$\begin{align*}T_\text{ici comms} = \frac{2BD}{W_\text{ici}} && T_\text{hbm comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{math} = \frac{2BDF}{Y \cdot C}\end{align*}$$ 对于 `4x8`,这会给出 $T_\text{ici comms}$ = `(2 * 64 * 8192) / 9e10 = 11us`,$T_\text{hbm comms}$ = `(2 * 8192 * 28,672) / (32 * 8.1e11) = 18us`,以及 $T_\text{math}$ = `(2 * 64 * 8192 * 28,672) / (32 * 1.97e14) = 4us`,因此理论上我们仍然是 HBM 带宽受限,这很好!*注意,从 `4x4` 扩展到 `4x8` 也许对吞吐量没有帮助,但它会降低延迟!* 如果我们看 int8 和 int4 配置,就**可以**用纯模型并行来做。也就是说,量化在这里带来的优势已经不只是 FLOPs 更快而已:它还允许我们在进入通信受限之前使用更大的 batch size。**因此,这个故事的结局是:我们无法在 4x8 上达到峰值吞吐,但对于 int8 和 int4 配置,我们可以使用纯模型并行。**

**提示**:有用的模型并行上限取决于 $$d_{ff}$$ 以及你沿多少个轴对模型做分片。这个最大值通常在 8 到 32 之间,具体取决于模型大小。你可以超过这个限制来降低延迟,但代价是吞吐量会有所损失。

Prefill 怎么办?

到目前为止,我们基本忽略了 prefill,因为它简单得多。现在让我们把几个概念拼起来,思考一下端到端的全景。

问题: 假设在 prefill 阶段我们能达到 40% 的 FLOPs 利用率。那么长度为 8192 的一次 prefill,在 16 片 TPU v5e 上需要多久?

答案 在 8k token 时,我们已经稳稳处于计算受限区域,因此只需要考虑 FLOPs。我们知道模型有 `70e9` 个参数,因此每次前向传播会使用 `2 * 70e9 * B` FLOPs。假设 MFU(FLOPs 利用率)为 40%,那么运行时间大约是 `2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s`。和前面我们一直看到的数字相比,这其实已经相当长了!

问题: 假设 prefill 长度中位数为 8192 token,decode 长度中位数为 4096 token。再假设 generate 的 batch size 为 32。平均来看,每一步会有多少个序列完成 decode?平均每一步会从 KV cache 中驱逐多少 token?

答案 这题其实很直接。由于 decode 长度中位数是 4096 token,因此一个序列大约每 1 / 4096 个 token 会完成一次。给定 batch size 为 32,这意味着我们每一步会有 `32 / 4096` 个序列被驱逐。由于 KV cache 的长度大约是 `8192 + 4096`,所以这意味着每一步会驱逐 `32 * (8192 + 4096) / 4096 = 96` 个 token。一般公式是 $B * (P + G) / G$,其中 $P$ 和 $G$ 分别是 prefill 和 generate 的长度。

问题: 假设我们采用解耦服务,prefill 长度中位数为 8192,decode 长度中位数为 512。再假设采用上面在 bfloat16 中算出的 prefill 和 generate 延迟。为了让两端都保持满载,你需要多大的 prefill:generate 服务器比例?

答案 这题挺有意思。设 $P$ 为 prefill 服务器数量,$G$ 为 generate 服务器数量。一般来说,这就是一个流水线问题:我们以 `P / prefill_latency` 的速率送入序列,以 `B * G / (generate_latency * median_decode_length)` 的速率消费序列。我们先前算出,prefill 每步是 `910ms`,而 batch size 43(这里把它近似成 32)时 decode 每步是 `19ms`。因此我们需要 `P / 0.91 = 32 * G / (0.019 * 512)`,也就是 `P = 3G`,换句话说,我们需要的 prefill 服务器大约是 generation 服务器的 3 倍!

延迟与吞吐量权衡的可视化

继续以 LLaMA 70B 为例,我们来实际看看 generation 阶段不同 batch size 下的延迟与吞吐量。正如我们在前一节对 PaLM 模型展示过的,这会给出一个吞吐/延迟的 Pareto 前沿。这里我们假设使用 16 路 tensor parallelism,因为这是在 MLP block 中保持计算受限时一个合理的上界。我们这里采用 TPU v5e 4x4 拓扑。滑块控制序列长度,因此你可以观察更大的 KV cache 带来的影响。

我们还可以通过把成本和延迟拆分为参数加载时间、KV 加载时间以及 FLOPs 时间,来更好地理解这件事。红色区域表示我们预计在 MLP block 中受计算限制的区域。

这张图讲述了一个非常清晰的故事。你会看到,在开始阶段,参数加载占据了绝大部分延迟;直到 batch size 足够大之后,FLOPs 和 KV 加载才开始变得更重要。尤其值得注意的是,在所有大于 2048 的序列长度下,我们花在 KV cache 加载上的时间都超过了 FLOPs 时间!因此,虽然通过增大 batch size 我们可以提升硬件利用率,但在长上下文长度下,KV 加载始终主导总步时延。

**要点:** 对于 LLaMA 3-70B 而言,在几乎所有这些配置下,我们都强烈地受到 KV cache 内存带宽限制(以及 HBM 限制),这也凸显了减少 KV cache 大小对 generation 吞吐量有多么重要。还要注意,延迟/吞吐量之间的权衡依然极其剧烈。

这段代码其实非常简单。 下面是用于计算这些屋顶线的代码:
import numpy as np

num_chips = 16  # 我们固定总模型并行为 16
bytes_per_param = 1  # int8 表示每个参数 1 字节
param_count = 70e9
param_size = bytes_per_param * param_count
sequence_length = 8192  # 这个值可以变化

hbm_bandwidth = 8.20E+11  # v5e
flops = 1.97E+14  # v5e

def kv_cache_size(bs):
    return 2 * bs * 128 * 8 * 80

def min_topology(bytes):
    return 2 ** np.ceil(np.log2(bytes / 16e9))

def get_max_batch_size(
    num_chips: int,
    sequence_length: int,
    param_size: float,
) -> int:
  batch_sizes = np.arange(1, 1024, 4)
  kv_sizes = kv_cache_size(sequence_length * batch_sizes)
  required_chips = min_topology(kv_sizes + param_size)
  max_idx = np.where(required_chips <= num_chips)[0][-1]
  return max_idx

max_idx = get_max_batch_size(
    num_chips=num_chips,
    sequence_length=sequence_length,
    param_size=param_size,
)  # 取能装下的最大 batch size
batch_sizes = np.arange(1, 512, 1)[:max_idx]
kv_sizes = kv_cache_size(sequence_length * batch_sizes)

kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)

param_comms_time = param_size / (num_chips * hbm_bandwidth)
param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])

flops_time = 2 * param_size * batch_sizes / (num_chips * flops)  # 在二维意义上近似成立

mlp_time = np.maximum(flops_time, param_comms_time)
attn_time = kv_comms_time  # generate 阶段始终受带宽限制

latency = 1000 * (mlp_time + attn_time)
throughput = batch_sizes / (latency * num_chips)
注意,我们在这里非常明确地把延迟拆成了两个来源:KV 加载和参数加载;同时,延迟是由 FLOPs 和通信中更大的那一个决定的。

习题详解

下面给出几道习题。有些内容会重复前文已经推导过的结论,但从教学角度看可能仍然有帮助。

问题 1: 对于 LLaMA 3-405B,每 token 的一次前向传播需要多少 FLOPs?假设我们受 FLOPs 限制,那么在 TPU v5e 上、使用 N 片芯片时,单次前向传播的时间下界是多少?如果我们受通信限制呢?忽略模型无法放进单芯片这一事实。

问题 2: 假设我们想以 BS240、int8 权重和 int8 KV cache 来服务 LLaMA 3-8B。那么 (a) 模型参数、(b) KV cache、(c) 峰值工作激活(粗略估计)分别会占用多少字节?最小可以跑在哪个拓扑上?

问题 3: 你会如何在 TPU v5e 上服务 LLaMA 3-405B?假设使用 int8 权重和 bfloat16 FLOPs。再假设我们有一个硬性约束:15ms / token,那么我们能实现的最高吞吐配置是什么?理论上的最小步时延是多少?

第 8 部分到这里就结束了!想阅读第 9 部分,深入了解 XLA 和 TPU profiling,请点击[这里](../profiling)。

第 9 章

如何进行 TPU 程序性能剖析

TPU 软件栈宏观概览

Google 提供了多种用于 TPU 编程的 API,从高级的 JAX 代码到低级的 Pallas 或 HLO。大多数程序员仅编写 JAX 代码,这使你可以编写抽象的、NumPy 风格的线性代数程序,这些程序会自动编译以在 TPU 上高效运行。

这里有一个简单的例子,一个将两个矩阵相乘的 JAX 程序:

import jax
import jax.numpy as jnp

def multiply(x, y):
  return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))

通过调用 jax.jit,我们指示 JAX 追踪此函数并生成称为 StableHLO 的低级 IR(一种用于 ML 计算的平台无关 IR),然后它又被 XLA 编译器降低(lower)为 HLO。编译器会运行许多传递(passes)以确定算子融合、内存布局和其他因素,这些最终构成了可以在 JAX 性能分析中观察到的 HLO。此 HLO 在 LLVM 风格的图视图中表示了 JAX 代码中的所有核心线性代数操作(矩阵乘法、逐点运算、卷积等)。例如,这是上述程序作为 HLO 的删节版本要获取此 HLO,你可以运行 jax.jit(f).lower(*args, **kwargs).compile().as_text()

ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1), metadata={op_name="y"}
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16]{1,0} %Arg_1.2),
  %Arg_0.1 = f32[128,256]{1,0} parameter(0), metadata={op_name="x"}
  ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1},
}

我们稍后会解释 HLO 的语法,但现在只需注意它实际上与上面的 JAX 代码匹配得相当好。例如,

ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1}

正是上面的实际矩阵乘法,它分别沿着第 0 和第 1 维将两个 f32 矩阵相乘。

为了将此 HLO 转换为可在 TPU 上执行的代码,XLA 编译器首先将其降低(lower)为 LLO(低级优化器)IR。LLO 直接对 TPU 进行编程,调度内存之间的拷贝,将数组推入脉动阵列(systolic array)等。LLO 代码包含将缓冲区推入脉动阵列、提取结果,以及调度在 TPU 内存的不同部分之间进行通信的 DMA 原语。一旦它被降低为 LLO,接着就会被编译成机器码,加载到 TPU IMEM 中并执行。

当程序运行速度低于预期时,我们通常在 JAX 级别进行操作以提升性能。然而,这样做往往需要我们理解一些 HLO 的语义,以及代码在 TPU 上实际的运行方式。当更底层出现问题时,我们会使用另一个“逃生舱”,在 Pallas 中编写自定义内核。要查看程序的 HLO 及其运行时统计信息,我们使用 JAX profiler。

JAX Profiler:多用途 TPU 性能分析工具

JAX 提供了一个多用途的 TPU 性能分析器(profiler),其中包含许多有用的工具,用于了解程序运行时 TPU 上正在发生什么。你可以使用 jax.profiler 模块在程序运行时追踪它,并记录从每个子组件的持续时间、每个程序的 HLO,到内存使用情况等所有信息。例如,此代码将转储(dump)一个追踪文件到 /tmp/tensorboard 中,该文件可以在 TensorBoard 中查看(这里 是分步指南)。

import jax
with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (1024, 1024))
  y = x @ x
  y.block_until_ready()

# 现在你可以在 Google Colab 中加载 TensorBoard,通过:
#
# !pip install tensorboard tensorboard-plugin-profile
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/tensorboard
#
# 或者在外部使用:
#
# > tensorboard --logdir=/tmp/tensorboard
#

以下是分析器功能的概览:

assets/img/xprof-overview.png

进入 TensorBoard 后,分析器有几个关键选项卡,可帮助你了解你的程序:

  1. 追踪视图 (Trace Viewer) 显示了 TPU 上实际发生操作的详细时间线。
  2. 图视图 (Graph Viewer) 显示了 HLO 图,让你可以看到程序的哪些部分相互馈送,以及事物是如何进行分片的(sharded)。
  3. 内存分析和内存视图 (Memory Profile and Memory Viewer): 这些显示了你的程序正在使用多少内存。

虽然分享分析结果有点困难,但这里有一个 Perfetto 链接,其中至少包含一个简单 Transformer 的追踪视图组件。这个 Colab 可以让你生成完整的 JAX/TensorBoard 追踪并进行探索。

追踪视图

追踪视图可能是分析器中最有用的部分。 下面的示例显示了一个带有部分标注的简单 Transformer。名称来自代码中提供的标签。

assets/img/trace-viewer.png

追踪视图显示了每个 TPU 核心上所有操作的按时间顺序排列的时间线。我们在这里只查看 TPU:0,因为通常所有 TPU 都执行相同的指令。几个关键点:

  1. 最上面一行 (XLA Ops) 显示了实际的 TPU 操作(名称即为 HLO 名称)。其他一切都是基于 jax.named_scopejax.named_call 和 Python 堆栈追踪的近似追踪。
  2. 注意重复的块,我们可以在此处隔离出单个网络层。我们还可以(通过查看代码/了解 Transformer 的工作原理)看到哪些部分是注意力机制 (attention),哪些部分是 MLP。
  3. 通过点击一个 XLA 操作,我们可以查看它来自代码的哪个位置(这对于理解追踪很有用),并看到指向图视图的链接。

**提示:** 你可以使用“电子游戏”风格的控件来浏览追踪视图,使用 A/D 键向左和向右平移,使用 W/S 键放大和缩小。这些控件使导航变得更加容易。

如何阅读 XLA 算子

HLO 实际上并不难读,它对理解上述追踪中给定部分对应的操作非常有帮助。这是一个名为 fusion.3 的算子示例。

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3

让我们将其分解为几个部分。

让我们尝试更深入地理解这个符号。以这个简单的例子为例:

f32[3,5]{1,0:T(2,2)}

它再次告诉我们,此算子返回一个形状为 [3, 5] 的 float32 数组,并具有特定的平铺策略 (tiling) {1,0:T(2,2)}。虽然平铺策略并非绝对重要,但简而言之,平铺策略告诉我们 N 维数组在内存中如何顺序排列。下面是显示该数组如何布局的图表:

assets/img/tiling.png

{1,0:T(2,2)} 中,1,0 部分告诉我们物理内存中数组维度的顺序,从最低阶到最高阶 (most minor to most major)。你可以从右到左阅读此部分,并在 f32[3,5] 中挑出相应的维度来确定数组的物理布局。在此示例中,物理布局为 [3,5],与逻辑形状相同。 之后,T(2,2) 告诉我们数组以 (2, 2) 的块 (chunks) 进行平铺,其中在每个块内,数组首先按行排列(行主序),然后按列排列,即 (0, 0) 之后是 (0, 1),然后是 (1, 0)(1,1)。由于 T(2, 2) 的平铺,数组被填充为 [4, 6],将其内存占用扩大了约 1.6 倍。对于上面给出的那个大型 bf16 数组,bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)},我们做 T(8,128)(2,1),这告诉我们该数组有两层平铺,即外部的 (8, 128) 平铺,和该单元内部的 (2, 1) 内部平铺(用于 bf16,因此我们的加载总是 4 字节的倍数)。例如,这是 bf16[4,8]{1,0:T(2,4)(2,1)}(颜色是 (2,4) 块,红色框是 (2,1) 块):

assets/img/tiling2.png

平铺策略会影响将张量块加载到 VMEM 的效率,XLA 有时会引入拷贝以在程序内部“重新平铺 (retile)”或“重新布局 (re-layout)”张量,有时会带来不小的开销。JAX 提供了一项实验性功能来规避这个问题,它允许 XLA 为程序输入计算其“首选 (preferred)”布局。当你使用 jax.jit“即时 (just-in-time)”编译程序时,你通常传入“模拟 (mock)”输入,告诉 JAX 期望的形状和数据类型。这些通常也带有平铺信息,但可能并非最优。相反,你可以将输入布局指定为 AUTO,jax.jit 将返回被 jit 编译的程序所偏好的布局。然后你可以显式地以此布局加载张量,以避免在程序内引发拷贝。

图视图

虽然上面的一些融合看似复杂,但 XLA 图视图使它们更容易解析。例如,这里是一个相当复杂的融合视图:

assets/img/graph-viewer.png

盯着一堆 HLO 图看,并尝试将 HLO 操作映射到你要分析的代码上,是非常有帮助的。通过将鼠标悬停在框上,你通常会看到定义该函数的代码行。

查看一个真实(或类似真实)的分析示例

这个 Colab 有一个用于模拟 Transformer 的示例分析。如果你赶时间的话,这里是一个 Perfetto 链接,至少可以看到追踪视图。我比平时花了更多精力用 jax.named_scope 调用来标注追踪,这样你就可以识别发生了什么。

assets/img/transformer-xprof.png

看看这个分析文件,试着真正理解每一部分在做什么。让我们稍微分解一下,从 FFW(前馈网络)块开始:

assets/img/transformer-ffw.png

这里我们放大了 FFW 块。你会看到向上投影 (up-projection) 算子是一个融合(矩阵乘法),输入为 bf16[8, 1024, 8192]bf16[8192, 16384],输出为 bf16[8, 1024, 16384]。我知道(因为这是我写的代码)这是一个四路数据并行 (DP)、两路模型并行 (MP) 分片矩阵乘法的本地视图,所以我们实际在进行的操作是:

X: bf16[32, 1024, 8192] * Win: bf16[8192, 32768] -> Tmp: bf16[32, 1024, 32768]

我们预期这会花费多长时间? 首先,每个数据并行分片的批次大小 (batch size) 是 8 * 1024 = 8192,因此我们应该是受限于计算 (compute-bound) 的。这是在 8 个 TPUv2 核心上(可在 Google Colab 上免费使用),因此我们预计它会花费约 2 * 32 * 1024 * 8192 * 32768 / (23e12 * 8) = 95.6ms,这实际上与其花费的时间(96ms)完全一致。太棒了!这意味着我们获得了出色的 FLOPs 利用率!

那通信方面呢? 你会注意到在第二个矩阵乘法末尾隐藏的一个小融合。如果我们点击它,你会看到:

%fusion.1 = bf16[8,1024,4096]{2,1,0:T(8,128)(2,1)} fusion(bf16[8,1024,8192]{2,1,0:T(8,128)(2,1)} %fusion.31), kind=kCustom, calls=%all-reduce-scatter.1

它基本上是一个小型的 ReduceScatter(这是图视图);

assets/img/reduce-scatter-xprof.png

我们预计这会花费多长时间?在 TPUv2 4x2 上进行 ReduceScatter,在 1.2e11 的双向带宽上只需一跳 (hop)。该数组的大小为 2*32*1024*8192,并在批次维度上分为四路分片,所以每个分片是 2*8*1024*8192=128MB。因此这应该需要大约 1.1ms。它实际花费了多长时间? 分析中报告为 1.13ms。所以我们非常接近 Roofline(理论上限)!

让我们也看看注意力机制 (attention)! 这是注意力机制组件的一个分析:

assets/img/attn-xprof.png

我点击了 Q 投影算子,它使用一个形状为 [dmodel = 8192, nheads = 32, dqkv = 256] 的矩阵 $$W_Q$$。我们正在沿着头 (head) 的维度使用 Megatron 分片策略。请尝试做同样的练习来计算这些操作应该花费多长时间。

内存分析

内存分析让程序内存在时间上的变化一目了然。这有助于调试内存溢出 (OOM) 的问题。你可以在这里看到大约有 7.5GB 分配给了模型参数,并且大约有 10GB 是空闲的。因此,我们可以在内存中容纳更多数据。

assets/img/memory-viewer.png

实战演练

问题 1:查看这个 Colab/分析,找出有什么可疑之处,以及发生了什么。你能告诉我究竟发生了哪些计算,每项操作具体在做什么吗?涉及的每个矩阵的真实形状是什么?它们是如何分片的?试着先只看分析,不看代码。

assets/img/all-reduce-profile.png
点击此处查看答案。 这是两个矩阵乘法,即具体如下:
def matmul(w1, w2, x):
  return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))
你可以看到一个规约 (reduce)、两个大型融合操作,以及一个全局规约 (all-reduce)。第一个大融合是: ```%fusion.1 = bf16[4096]{0:T(1024)(128)(2,1)} fusion(bf16[4096,8192]{1,0:T(8,128)(2,1)} %param.1, bf16[8192]{0:T(1024)(128)(2,1)} %reduce.6), kind=kLoop, calls=%fused_computation.1``` 它告诉我们每个分片的形状是 `bf16[8192] * bf16[4096, 8192] -> bf16[4096]`(在 8192 维度上)。通过观察具有 `replica_groups={{0,16,32,48,64,80,96,112}, ...}` 的最终 AllReduce,我们可以判断我们正在进行 8 路模型并行,因此真实的形状是 `[8, 8192] * bf16[32768, 8192] -> bf16[8, 32768]`。

问题 2: 前文提到的 Transformer Colab 实现了一个简单的模拟 Transformer。遵循 Colab 中的说明,获取一个带有 GSPMD 分区的朴素 Transformer 基准测试。每个部分花费了多长时间?应该花费多长时间?正在使用哪种分片策略?尝试修复这个分片!提示:使用 jax.lax.with_sharding_constraint 来约束其行为。经过此修复,你能获得的最佳 MXU 利用率是多少?

作为参考,最初版本的结果大约是 184 毫秒/层,优化后的分析结果是 67 毫秒/层。一旦你完成了这些,试着盯着分析图表看,看看你是否能纯粹从分析中回答这些问题:

注意: 自从写下这个问题以来,XLA 编译器变得更强了。最初的版本现在大约为 90 毫秒/层,优化后的分析图表大约缩短了 10 毫秒/层(80 毫秒/层)。尽管如此,它仍然值得一试,看看你是否能做得更好。

这就是第 9 部分的全部内容。关于第 10 部分的 JAX 并行深潜,请点击[这里](../jax-stuff)。

第 10 章

JAX TPU 编程

JAX 中的并行是如何工作的?

JAX 支持三种多设备编程流派:

  1. 编译器掌舵! 让 XLA 编译器自动对数组进行分区,并决定添加哪些通信来协助给定的程序。这让你可以将运行在单个设备上的程序自动运行在数千个设备上,而无需更改任何代码。
  2. JAX 掌舵! 自动并行很棒,但有时编译器会做出一些疯狂的事情。显式分片允许你像往常一样编写单设备代码,但由 JAX 处理分片传播(而不是编译器)。这意味着当你的意图不明确时,JAX 可以要求你进行澄清。
  3. 让我直接写出我的意图,该死! 虽然编译器很好,但它们有时会做错误的事情,并添加你并不想要的通信。有时我们希望明确指定我们打算运行哪些通信。
模式 视图 显式分片? 显式集合通信?
自动 (Auto) 全局
显式 (Explicit) 全局
手动 (Manual) 每个设备

相应地,JAX 为每种模式提供了 API:

  1. jax.jit(配合 Auto 网格轴)允许你获取任何现有的 JAX 函数并使用分片输入调用它。JAX 随后使用 XLA 的 Shardy 编译器,该编译器会自动并行化程序。当需要协助现有操作时,XLA 会为你添加通信(AllGathers、ReduceScatters、AllReduces 等)。虽然它并不完美,但它通常能很好地将你的程序自动扩展到任意数量的芯片,而无需更改代码。
  2. 带有 Explicit 网格轴的 jax.jit 看起来与 (1) 类似,但让 JAX 处理分片传播而不是 XLA。这意味着数组的分片实际上是 JAX 类型系统的一部分,当检测到模糊的通信时,JAX 可以报错并让用户解决它。
  3. jax.shard_map 是更手动的对应方案。你可以获得程序的设备局部视图,并且必须显式编写你想要的任何通信。有一个分片数组并希望每个设备上都有完整数组?添加一个 jax.lax.all_gather。想要跨设备对数组求和?添加一个 jax.lax.psum(一个 AllReduce)。编程更难,但极不可能做出你不想要的事情。

自动分片模式 (Auto sharding mode)

jax.jit 在 JAX 中扮演两个角色。顾名思义,它将函数从 Python “即时”编译为字节码(通过 XLA/HLO/LLO),从而使其运行得更快。但如果输入是分片的,或者用户指定了 in_shardingout_sharding,它还允许 XLA 在多个设备上分发计算并根据需要添加通信。例如,以下是如何使用 jax.jit 编写分片矩阵乘法:

import jax
import jax.numpy as jnp

# 在 TPU v5e 4x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# 这告诉 JAX 将此网格用于所有操作,因此你只需指定分区规范 (PartitionSpec) P。
jax.set_mesh(mesh)

# 我们创建一个矩阵 W 和输入激活 In,它们在我们的设备上分片。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# 我们可以在此处显式编译分片矩阵乘法函数。这会添加所有
# 必要的通信(例如矩阵乘法后的 AllReduce)。
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()

out = jit_matmul(In, W)

这将自动运行任何分片,并在我们的设备上划分计算。但硬件层面实际上发生了什么?

  1. 首先,我们创建了跨设备分片的 In 和 W注意我们是如何做到这一点的。这是创建具有特定分片的数组的一种方法(即通过在创建函数中添加 device 参数)。另一种方法是使用 jnp.array(....) 正常创建数组,然后执行例如 jax.device_put(..., jax.P('x', 'y'))。还有一种方法是编写一个创建你想要数组的函数,并使用 out_shardings 为你想要的内容进行 jit 编译。。W 沿收缩维度(contracting dimension)进行 2 路分片,而 In 沿收缩维度和输出维度进行 4 路分片。这对应于 W[DY, F] 和 In[BX, DY] 的分片,即一种模型和数据并行。
  2. 如果我们在本地运行(即在一个设备上),matmul_square 只会对输入求平方并执行简单的矩阵乘法。但因为我们将 out_shardings 指定为 P('X', None),输出将沿 batch 分片,但在模型维度上是复制的,并且需要一个 AllReduce 来计算。

使用我们前面章节的记号,这可能会执行如下操作:

  1. Out[BX, F] { UY } = In[BX, DY] *D W[DY, F]
  2. Out[BX, F] = AllReduce(Out[BX, F] { UY })

jax.jit 将自动为我们添加这些!我们实际上可以使用 jit_matmul.as_text() 打印 HLO,并看到以下 HLO(大幅简化):

# 这个 fusion 是分片输入和矩阵的实际矩阵乘法
%fusion = bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[2,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)

# 我们跨设备对部分求和的结果进行 reduce
ROOT %AllReduce = bf16[2,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)

我们可以在上面看到矩阵乘法(fusion)和 AllReduce。特别注意形状。bf16[2, 1024] 是激活的局部视图,因为我们的 batch_size=8 被拆分到 4 个设备上,而我们的 d_model=2048 也同样被拆分到 2 个设备上。

这非常神奇! 无论我们的程序多么复杂,Shardy 和 jit 都会尝试为所有中间激活找到分片,并根据需要添加通信。话虽如此,Shardy 也有其缺陷。它可能会犯错。有时你会查看分析结果并注意到出了问题。一个巨大的 AllGather 占据了分析结果的 80%,而它本不需要。发生这种情况时,我们可以尝试通过使用 jax.lax.with_sharding_constraint 显式注释中间张量来纠正编译器。例如,在两个矩阵乘法中,我可以通过以下方式强制中间激活沿 y 维度分片(尽管这不一定是个好主意):

import jax
import jax.numpy as jnp

mesh = jax.make_mesh((4, 2), ('X', 'Y'))

def matmul(x, Win, Wout):
  hidden = jnp.einsum('bd,df->bf', x, Win)
  hidden = jax.lax.with_sharding_constraint(hidden, jax.P('X', 'Y'))
  return jnp.einsum('bf,df->bd', hidden, Wout)

在自动分区世界中,这大约占据了 JAX 并行编程的 60%,在这里你通过 jax.lax.with_sharding_constraint 控制中间分片。但“逗弄编译器”显然不是一种有趣的编程模型。你可以注释每一个中间变量,但仍然不知道是否会得到正确的结果。相反,如果 JAX 自身能够处理并控制分片传播会怎样?

显式分片模式 (Explicit sharding mode)

显式分片(或“类型中的分片”)看起来很像自动分片,但分片传播发生在 JAX 层面!每个 JAX 操作都有一个分片规则,该规则获取操作参数的分片,并为操作结果产生一个分片。你可以使用 jax.typeof 查看产生的分片:

import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

# 在 TPU v5e 2x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'),
                                       axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))

# 这告诉 JAX 将此网格用于所有操作,因此你只需指定分区规范 P。
jax.set_mesh(mesh)

x = jax.device_put(np.arange(16).reshape(8, 2), jax.P('X', 'Y'))

@jax.jit
def f(x):
  print(jax.typeof(x))  # bfloat16[8@X,2@Y]
  out = x * 2
  print(jax.typeof(out))  # bfloat16[8@X,2@Y]
  return out

f(x)

如你所见,JAX 将分片从输入 (x) 传播到了输出 (x),这些分片在追踪时(trace-time)可以通过 jax.typeof 进行检查。对于大多数操作,这些规则简单明了,因为只有一种合理的选择(例如逐元素操作保留相同的分片)。但对于某些操作,如何对结果进行分片是模糊的,在这种情况下 JAX 会抛出追踪时错误,我们要求程序员显式提供 out_sharding 参数(例如 jnp.einsum、jnp.reshape 等)。让我们看另一个出现冲突的例子:

# 我们创建一个矩阵 W 和输入激活 In,它们在我们的设备上分片。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, out_sharding=jax.P('Y', None))

@jax.jit
def matmul_square(In, W):
  print(jax.typeof(In))  # bfloat16[8@X, 2048@Y]
  print(jax.typeof(W))  # bfloat16[2048@Y, 8192]
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

matmul_square(In, W)  # 这将报错

这段代码报错:Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the out_sharding parameter. Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)

这很棒,因为 einsum 的输出应该如何分片确实是模糊的。输出分片可以是: * P('X', 'Y'),这将引发一个 reduce-scatter,或者 * P('X', None),这将引发一个 all-reduce

与自动模式不同,显式模式在检测到模糊通信时会报错,并要求用户解决。所以在这里你可以这样做:

@jax.jit
def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W, out_sharding=jax.P('X', 'Y'))

out = matmul_square(In, W)
print(jax.typeof(out))  # bfloat16[8@X,8192@Y]

自动模式和显式模式可以通过 jax.sharding.auto_axesjax.sharding.explicit_axes API 进行组合。这是一份非常值得阅读的文档,可以获取更多信息。

通过 shard_map 实现的手动分片模式

虽然 Shardy 是“编译器掌舵”模式,但 JAX shard_map 将一切交到了你手中。你指定输入的分片(类似于 jax.jit),但随后你显式编写所有通信。jax.jit 为你提供程序的全局跨设备视图,而 shard_map 为你提供局部的每个设备视图。

这是一个例子。试着推断这个函数的作用:如果你想在 Colab 中通过模拟网格自行尝试,可以使用以下单元格:import jax; jax.config.update('jax_num_cpu_devices', 8)

import jax
import jax.numpy as jnp
import jax.sharding as shd

mesh = jax.make_mesh((2, 4), ('x', 'y'), (shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=jax.P(('x', 'y')))

# 此函数将在数组的 1/8 上运行。
@jax.shard_map(in_specs=jax.P(('x', 'y')), out_specs=jax.P())
def slice_and_average(x):
  assert x.shape == (512 // 8,)
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)
assert out.shape == (4,)

这是在做什么? slice_and_average 在每个 TPU 上运行,处理数组的 1/8,从中切片前 4 个元素并在整个网格上取平均值。这意味着我们实际上是在执行 mean(x[:4], x[64:68], x[128:132], …)。这非常酷,因为否则在 JAX 中很难表达这种操作。

为什么要用这个而不是 jax.jit? 如果我们使用 jax.jitslice_and_average 将看到数组的全局视图(整个 [512,] 数组)。我们必须切出这个非均匀的切片,然后执行一个平均,而 XLA 必须正确解释它。XLA 可能会添加错误的通信或感到困惑。在这里,我们看到局部视图并只编写我们需要的通信。

示例 [集合矩阵乘法 (Collective Matmul)]: 举一个更现实的例子,假设我们要实现模型并行,其中激活最初是模型分片的,即 A[BX, DY] * W[D, FY] -> Out[BX, FY]。天真地,我们会通过先 AllGather A,然后进行局部矩阵乘法来实现:

  1. A[BX, D] = AllGatherY(A[BX, DY])
  2. Out[BX, FY] = A[BX, D] *D W[D, FY]

不幸的是,这很糟糕,因为它不允许我们将通信与计算重叠。重叠它们可以使用“集合矩阵乘法”来完成,如 Wang 等人 2023 中所述。该算法基本上如下:

我们可以很容易地用 jax.shard_map 实现这一点:

import functools

import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

# 这旨在在 TPU v5e-8 运行时上运行。如果你无法获得,
# 请尝试设置 jax.config.update('jax_num_cpu_devices', 8)。
#
mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'),
                                       axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

B, D, F = 1024, 2048, 8192
A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))

A = jax.device_put(A, jax.P('X', 'Y'))
W = jax.device_put(W, jax.P(None, 'Y'))

@functools.partial(jax.jit, out_shardings=jax.P('X', 'Y'))
def matmul(lhs, rhs):
  return lhs @ rhs

def collective_matmul_allgather_lhs_contracting(lhs, rhs):
  # lhs 是循环操作数;rhs 是局部操作数
  axis_size = jax.lax.axis_size('Y')  # 在此示例中 axis_size = 4
  idx = jax.lax.axis_index('Y')

  chunk_size = lhs.shape[1]
  assert rhs.shape[0] % chunk_size == 0

  def f(i, carrys):
    accum, lhs = carrys
    rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    # 对一个块进行矩阵乘法
    update = lhs @ rhs_chunk
    # 向左循环置换
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='Y',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    return accum + update, lhs

  accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
  accum = jax.lax.pvary(accum, ('X', 'Y'))
  accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)

  # 在最后一次置换后计算最后一个块,使 lhs 保持在我们发现它时的状态
  i = axis_size - 1
  rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
  update = lhs @ rhs_chunk
  return accum + update

jit_sharded_f = jax.jit(jax.shard_map(
  collective_matmul_allgather_lhs_contracting,
  in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y')), out_specs=jax.P('X', 'Y')))

shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)

np.testing.assert_array_equal(shmapped_out, expected_out)

这非常巧妙!我们可以对此进行基准测试,发现它也快得多!这里是默认 jit 矩阵乘法的分析图,耗时 311us,开头有一个大的阻塞式 AllGather:

assets/img/not-overlapped.png

这里是上面的版本,耗时 244 us。你可以看到分析图中没有 AllGather。全都是有用的工作!我们的 FLOPs 利用率也高得多。

assets/img/overlapped.png

还值得注意的是,收缩维度不分片时的矩阵乘法时间是 224us,所以我们在这里非常接近不分片的基准。这是一个很好的例子,说明了你为了提高 TPU 利用率最终可能会进行的性能工程。欲了解更多 shard_map 示例,此说明非常棒

现在,这里有几个有用的习题,尝试使用 jax.jitshard_map 来实现!

习题

这里有一些随机的 JAX 相关问题。我稍后会再添加一些。对于所有这些习题,你需要在 Colab 中拥有一些 TPU。你可以使用带有 TPUv2-8 的公开 Colab。从现在起,我们假设你有 N 个可用设备。

问题 1:A 为形状为 float32[SX, DY] 的激活数组,其中 X * Y = N。执行以下操作:

  1. 编写一个 JAX 函数,计算每个 (X, Y) 分片内的平均值,即它返回一个大小为 [X, Y] 的数组,其中 arr[i, j] 是分片 (i, j) 的平均值。分别使用 jax.jitshard_map 完成此操作。分析每个函数的耗时。是否添加了通信?提示:不应该有通信,但有时 XLA 还是会添加。

  2. 编写一个 JAX 函数,返回 每个分片 X 内 的 roll(x, shift, axis=0) - x(对于某个 shift)。我还不至于受虐到让你在 jax.jit 中做这个,所以直接用 shard_map 做吧。

点击此处查看答案。 第 1 部分:这是第 1 部分的一种解决方案。注意我们在 `jax.jit` 解决方案中必须做的相当复杂的重塑 (reshape)。
import numpy as np

import jax
import jax.numpy as jnp

mesh = jax.make_mesh((4, 2), ('X','Y'))

average_shmap = jax.shard_map(
    lambda x: x.mean(keepdims=True),
    mesh=mesh,
    in_specs=jax.P('X','Y'), out_specs=jax.P('X','Y')
)

def average(x):
  X, Y = mesh.axis_sizes
  return x.reshape(X, x.shape[0] // X, Y, x.shape[1] // Y).mean(axis=(1, 3))

average_jit = jax.jit(average, out_shardings=jax.NamedSharding(mesh, jax.P('X','Y')))

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, jax.P('X','Y')))

y1 = average_shmap(x)
y2 = average_jit(x)

np.testing.assert_array_equal(y1, y2)
第 2 部分:这是第 2 部分的一个类似解决方案。
import numpy as np

import jax
import jax.numpy as jnp

import functools

P = jax.sharding.PartitionSpec

mesh = jax.make_mesh((4, 2), ('X','Y'))

def shift_shmap(x, shift: int):
  shmapped = jax.shard_map(
      lambda x: jnp.roll(x, shift, axis=0),
      mesh=mesh,
      in_specs=jax.P('X','Y'), out_specs=jax.P('X','Y')
  )
  return shmapped(x)

@functools.partial(jax.jit, static_argnames=['shift'], out_shardings=jax.NamedSharding(mesh, jax.P('X','Y')))
def shift_jit(x, shift: int):
  X, Y = mesh.axis_sizes
  reshaped = x.reshape(X, x.shape[0] // X, -1)
  return jnp.roll(reshaped, shift, axis=1).reshape(x.shape[0], x.shape[1])

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, jax.P('X','Y')))

y1 = shift_shmap(x, 5)
y2 = shift_jit(x, 5)

np.testing.assert_array_equal(y1, y2)

问题 2: 这里我们将一起制作一个基础的“混合专家” (mixture of experts) 模型。设 W: float32[EX, D, F] 为一组 E 个“专家”矩阵。设 A: float32[SX, D](我们的激活值)以及 B: int32[SX] 为一组“路由分配”,其中 B[i] 是 [0, E) 范围内的整数,告诉我们希望哪个矩阵处理该激活值。我们想编写一个 JAX 函数,返回 Out[i] = W[B[i]] @ A[i]

  1. 我们先从完全忽略分片开始。让所有这些张量都足够小,以便它们能放入一个设备中。编写该函数的局部实现。确保你没有实例化形状为 [S, D, F] 的数组!提示:尝试将 token 排序到一个形状为 [E, S, D] 的新缓冲区中,并注意掩码(想想为什么我们需要第二个维度的大小为 S?)。

  2. 如果你只是对上述方法进行 jax.jit,会发生一些事情。分析这个过程并查看它决定执行哪些通信。耗时多久?

  3. 你会注意到上述方法的一个问题是,它可能会在本地收集完整的激活集 A,即 AllGatherX([SX, D])。这不仅在通信上很昂贵,而且如果我们在本地无法容纳完整的激活集,那么在内存上也会非常昂贵。使用 shard_map 和显式通信来实现上述功能。

    1. 第一步,使用 jax.lax.all_gather 并在 (a) 中重新排序可能最简单。

    2. 第二步,尝试避免实例化任何大小为 [E, S, D] 的数组,即尝试在 jax.lax.while_loop 中使用 jax.lax.all_to_all 以不规则(ragged)的方式执行计算。这样,你可以避免实例化完整的激活值并浪费计算资源在填充 (padding) 上。这比你最初的实现快多少?

  4. 大多数 MoE 路由到多个 (k) 专家,然后对结果取平均。重构上述代码以实现此功能。在这种情况下,设 B: int32[S, k] 为要路由到的 k 个专家。

点击此处查看(部分)答案。 1/2. 对于第 (1) 部分,你有很多选择。这里有一个选项,只是通过掩码迭代专家。
def moe_local(W: jnp.ndarray, A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    S, _ = A.shape
    E, _, F = W.shape

    def expert_forward(carry, e):
        output = carry  # [S, F]
        mask = (B == e)[:, None]  # [S, 1]
        expert_result = A @ W[e]  # [S, F] - 该专家对所有 token 的变换
        output = output + expert_result * mask  # 只保留分配的 token 的结果
        return output, None

    output = jnp.zeros((S, F))
    output, _ = lax.scan(expert_forward, output, jnp.arange(E))

    return output
你也可以使用 `jax.lax.ragged_dot`,它会执行类似的操作但效率更高。 3. 我在这里只简述一下伪代码(如果你有干净的解决方案,欢迎添加):
chunk_size = 128
def matmul(W, x, B):
  i = 0
  x = # 根据分配对 x 进行排序
  while (chunk := x[i:i+chunk_size].any()):
     chunk = all_to_all(chunk)
     out = matmul_local(W, chunk)
  return concat(out)
基本思想是迭代数组的块,对它们进行排序并执行 all_to_all,然后执行局部 FLOP。

问题 3: 上述集合矩阵乘法示例实际上与真实的 LLM 非常相关。让我们微调该示例以完成整个 Transformer 栈。

  1. 作为一个练习,我们先来实现一个 AllReduce 集合矩阵乘法,即 A[BX, DY] *D W[DY, F] -> Out[BX, F]。注意输出不是复制的。朴素算法在上面已经讨论过,基本上只是一个局部矩阵乘法,后跟一个 AllReduce。尝试制作一个通信重叠的“集合”版操作。提示:对输出维度进行分块,并随意使用 jax.lax.psum(又名 AllReduce)。 注意:由于 XLA 处理此操作的方式,它实际上可能不会比基准线更快。

  2. 与上述 AllReduce 集合矩阵乘法互补的是 ReduceScatter 集合矩阵乘法,如 Tmp[BX, FY] *F W2[FY, D] -> Out[BX, DY]。这发生在 Transformer 的下投影 (down-projection) 矩阵中。在 JAX 中实现一个集合的、重叠的版本。注意只传递你需要的最小数据量。提示:尝试在累加结果时对其进行置换。

  3. 将这两个部分组合成一个端到端的 Transformer 块,执行 In[BX, DY] *D Win[D, FY] *F Wout[FY, D] -> Out[BX, DY] 并带有重叠通信。如前所述,由于我们在此省略了非线性激活函数,因此不能先计算 $W_{in} \cdot W_{out}$。 这比 jax.jit 实现快多少?

问题 4: 上述实现的所有集合矩阵乘法都是单向的:它们只在一个方向上置换。重写集合 AllReduce 矩阵乘法和集合 ReduceScatter 矩阵乘法,以使用双向通信。这些能快多少?

第 10 部分到此结束。基本上就是这样!有关最终结论和进一步阅读,请点击此处

第 11 章

结语与延伸阅读

感谢你读完整本书,也恭喜你一路坚持到了最后。 在正式结束之前,我们先做几则致谢:

致谢

这份文档凝聚了 Google DeepMind 许多人的大量共同投入,我们想在这里简要致谢!

我们也要感谢许多在整个过程中提出关键反馈的人,尤其包括 Zak Stone、Nikhil Sethi、Caitlin Stanton、Alek Dimitriev、Sridhar Lakshmanamurthy、Albert Magyar、Diwakar Gupta、Jeff Dean、Corry Wang、Matt Johnson、Peter Hawkins,以及许多其他人。也感谢 Ruiqi Gao 在 HTML 格式化方面提供的帮助。

感谢大家!

在离开之前,你或许也会想读一读关于 NVIDIA GPU 的全新[第 12 部分](../gpus)!

延伸阅读

还有许多相关资料值得一读,包括以下内容:

在这个领域,仍然有很大空间容纳更系统、更全面的写作,因此我们希望这份手稿能鼓励更多相关工作出现!我们也相信,这是一个非常值得学习和研究的方向。在很多情况下,即使手头没有大量硬件加速器,也依然可以开展这方面的工作。

反馈

如果你有意见或问题,请告诉我们,以便我们继续改进这份内容。你可以通过 jacobaustin123 [at] gmail [dot] com 联系我们的通讯作者 Jacob Austin,或者在 GitHub 上通过 issue、pull request 或 discussion 提出修改建议。

第 12 章

如何理解 GPU

什么是 GPU?

现代机器学习 GPU(例如 H100、B200)本质上可以看作是一大堆专门做矩阵乘法的计算核心(称为 Streaming Multiprocessor,即 SM),再连上一条高速内存(称为 HBM)。如下图所示:

<b>图:</b>展示 H100 或 B200 GPU 抽象布局的示意图。H100 有 132 个 SM,而 B200 有 148 个。我们这里较宽泛地用“Warp Scheduler”来指代一组 32 路 CUDA SIMD 核心<i>以及</i>负责给它们分发工作的调度器。注意它看起来和 TPU 多么相似!
图:展示 H100 或 B200 GPU 抽象布局的示意图。H100 有 132 个 SM,而 B200 有 148 个。我们这里较宽泛地用“Warp Scheduler”来指代一组 32 路 CUDA SIMD 核心以及负责给它们分发工作的调度器。注意它看起来和 TPU 多么相似!

每个 SM,类似于 TPU 的 Tensor Core,都有一个专用的矩阵乘法核心(不幸的是它也叫 Tensor CoreGPU Tensor Core 指的是 SM 内部负责矩阵乘法的子单元,而 TPU TensorCore 指的是包含 MXU、VPU 及其他组件的上层整体单元。)、一个向量算术单元(称为 Warp SchedulerNVIDIA 并没有给这个单元一个特别好的名字,所以我们这里只是在若干个不太理想的选项中选了相对最合适的一个。Warp Scheduler 严格说主要是把工作分发给一组 CUDA 核心的单元,但在这里我们用它来同时表示控制单元以及它所控制的那组核心。),以及一个高速片上缓存(称为 SMEM)。与一个 TPU 至多只有 2 个独立 “Tensor Core” 不同,现代 GPU 拥有 100 多个 SM(H100 上是 132 个)。这些 SM 中的每一个都远不如 TPU Tensor Core 强大,但整个系统因此更灵活。每个 SM 基本上都相互独立,所以 GPU 可以同时执行数百个彼此独立的任务。尽管 SM 彼此独立,但为了达到峰值性能,它们通常仍需要协同工作,因为它们共享一个容量有限的 L2 缓存。

让我们更细致地看一下 H100 的一个 SM:

<b>图:</b>H100 的一个 SM 示意图(<a href='https://wccftech.com/nvidia-hopper-gh100-gpu-official-5nm-process-worlds-fastest-hpc-chip-80-billion-transistors-hbm3-memory/'>来源</a>),展示了 4 个<i>子分区</i>,每个子分区都包含一个 Tensor Core、Warp Scheduler、寄存器文件,以及多组不同精度的 CUDA Core。底部附近的 “L1 Data Cache” 就是 256kB 的 SMEM 单元。B200 与之类似,但增加了大量 Tensor Memory(TMEM)来喂给体积更大的 Tensor Core。
图:H100 的一个 SM 示意图(来源),展示了 4 个子分区,每个子分区都包含一个 Tensor Core、Warp Scheduler、寄存器文件,以及多组不同精度的 CUDA Core。底部附近的 “L1 Data Cache” 就是 256kB 的 SMEM 单元。B200 与之类似,但增加了大量 Tensor Memory(TMEM)来喂给体积更大的 Tensor Core。

每个 SM 被划分为 4 个相同的象限,NVIDIA 称之为 SM 子分区(SM subpartition)。每个子分区都包含一个 Tensor Core、16k 个 32 位寄存器,以及一个名为 Warp Scheduler 的 SIMD/SIMT 向量算术单元;该单元中的各条算术通道(ALU)被 NVIDIA 称为 CUDA Cores。每个分区中最核心的组件可以说是 Tensor Core,因为它负责矩阵乘法,也贡献了绝大多数 FLOPs/s,但这并不是唯一值得关注的部分。

CUDA cores 比 TPU 的 VPU 更灵活: 自 V100 起,GPU CUDA cores 使用的是所谓的 SIMT(Single Instruction Multiple Threads,单指令多线程)编程模型,而 TPU 使用的是 SIMD(Single Instruction Multiple Data,单指令多数据)模型。与 TPU VPU 中的 ALU 一样,同一子分区中的 CUDA cores 在每个周期必须执行同一个操作(例如如果一个 core 在做两个浮点数相加,那么该子分区中的其他 CUDA core 也必须这样做)。但与 VPU 不同的是,每个 CUDA core(或者说 CUDA 编程模型中的“线程”)都有自己的指令指针,因此可以被独立地 编程。当同一 warp 内两个线程被要求执行不同操作时,实际上你会把 两种 操作都执行一遍,只是在某一步里把不需要该操作的核心屏蔽掉。

<b>图:</b>一组线程内部发生 warp divergence 的示例(<a href='https://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf'>来源</a>)。白色空白区域表示至少有一部分物理 CUDA 核心处于停顿状态。
图:一组线程内部发生 warp divergence 的示例(来源)。白色空白区域表示至少有一部分物理 CUDA 核心处于停顿状态。

这让线程级编程非常灵活,但代价是如果 warp 经常分歧,性能会悄无声息地下降。线程在可访问的内存上也更灵活;VPU 只能处理连续内存块,而 CUDA cores 可以访问共享寄存器中的单个浮点数,并维护线程私有状态。

CUDA core 的调度也更灵活: SM 的运行有点像多线程 CPU,因为它能够并发“调度”许多程序(即 warp,每个 SM 最多 64 个),但每个 Warp Scheduler 在每个时钟周期只会执行一个程序。被调度到某个 SM 上的 warp 被称为 “resident”。 Warp Scheduler 会自动在活跃 warp 之间切换,以隐藏诸如内存加载这样的 I/O 操作。相比之下,TPU 通常基本可看作单线程。

内存

除了计算单元之外,GPU 还有一套层级化内存结构,其中最大的当然是 HBM(主 GPU 内存),下面还有一系列更小的缓存(L2、L1/SMEM、TMEM、寄存器内存)。

GPU 规格汇总

以下总结了近几代 GPU 的规格。对于同一款 GPU 的不同变体,SM 数量、时钟频率和 FLOPs 会略有差异。先看内存容量:

GPU 代际 时钟频率 每芯片 SM 数 每个 SM 的 SMEM 容量 每芯片 L2 容量 每芯片 HBM 容量
V100 Volta 1.25GHz/1.38GHz 80 96kB 6MB 32GB
A100 Ampere 1.10GHz/1.41GHz 108 192kB 40MB 80GB
H100 Hopper 1.59GHz/1.98GHz 132 256kB 50MB 80GB
H200 Hopper 1.59GHz/1.98GHz 132 256kB 50MB 141GB
B200 Blackwell ? 148 256kB 126MB 192GB

所有代际每个 SM 都有 256kB 的寄存器内存。Blackwell 另外还为每个 SM 增加了 256kB 的 TMEM。下面是各芯片的 FLOPs 和带宽数据:

GPU 代际 每芯片 HBM 带宽 每芯片 FLOPs/s(bf16/fp16) 每芯片 FLOPs/s(fp8/int8) 每芯片 FLOPs/s(fp4)
V100 Volta 9.0e11
A100 Ampere 2.0e12 3.1e14 6.2e14
H100 Hopper 3.4e12 9.9e14 2.0e15
H200 Hopper 4.8e12 9.9e14 2.0e15
B200 Blackwell 8.0e12 2.3e15 4.5e15 9.0e15

我们不包含 B100,因为它并未大规模量产。尽管 NVIDIA 曾推出 B100 这一代产品,但据称由于设计缺陷,它们只能短暂销售和生产,且无法接近标称规格运行。由于散热和功耗问题,它们在接近峰值 FLOPs 时容易降频。 某些规格会随 GPU 的具体版本略有不同,因为 NVIDIA GPU 并不像 TPU 那样标准化。

下面这张对照表可以帮助快速比较 GPU 与 TPU 的各个部件:

GPU TPU 它是什么?
Streaming Multiprocessor (SM) Tensor Core 含有多个子单元的核心“单元格”
Warp Scheduler VPU SIMD 向量算术单元
CUDA Core VPU ALU SIMD ALU
SMEM (L1 Cache) VMEM 高速片上缓存内存
Tensor Core MXU 矩阵乘法单元
HBM (aka GMEM) HBM 高带宽大容量内存

芯片层面的 GPU 与 TPU 对比

GPU 最初是为电子游戏渲染而生,但随着 2010 年代深度学习的兴起,它们越来越像专用矩阵乘法机器,换句话说,也越来越像 TPU。在深度学习爆发之前,GPU(Graphics Processing Unit)确实主要用于图形处理,尤其是电子游戏。游戏通常用数百万个小三角形表示物体,而渲染器会将这些三角形投影并栅格化成二维图像,每秒在屏幕上刷新 30 到 60 次(这个频率就是帧率)。栅格化需要把这些三角形投影到相机坐标系中,并不断计算哪些三角形覆盖哪些像素,规模可达每秒数十亿次。显然,这个过程计算代价非常高,而且这还只是开始。之后还要把与光线相交的若干半透明三角形的颜色混合起来,为每个像素上色。GPU 被设计成以极高速度完成这些操作,同时保持高度通用性;你需要让许多不同的 GPU 工作负载(称为 “shader”)同时运行,而没有哪一种单一操作绝对占主导。因此,面向消费者图形的 GPU 当然也能做矩阵乘法,但那并不是它们最初的主要使命。 从某种程度上说,这段历史解释了现代 GPU 为何长成今天这个样子。它们并不是为 LLM 或 ML 模型纯粹从零设计出来的,而是通用加速器,因此硬件天然追求一定程度的“通用性”,这既可能是优点,也可能是负担。GPU 应用于新任务时更常“直接就能跑起来”,对优秀编译器的依赖也远低于 TPU。但反过来,这也让 GPU 更难推理,更难真正跑到 roofline 性能,因为编译器的很多特性都可能成为瓶颈。

GPU 更模块化。 TPU 有 1 到 2 个大型 Tensor Core,而 GPU 则有数百个较小的 SM。类似地,每个 TPU Tensor Core 里有一个由 4 个可独立编程的 8x128 单元组成的大型 VPU(总共 4096 个 ALU);相比之下,H100 有 132 * 4 = 528 个独立 SIMD 单元,每个宽度为 32(总共 1.6 万个 ALU)。下面是一个 1:1 的 GPU/TPU 对照,能更清楚地体现这个差异:

GPU TPU H100 数量 TPU v5p 数量
SM(streaming multiprocessor) Tensor Core 132 2
Warp Scheduler VPU 槽位 528 8
SMEM(L1 cache) VMEM 32MB 128MB
寄存器 Vector Registers(VRegs) 32MB 256kB
Tensor Core MXU 528 8

这种模块化差异一方面让 TPU 更便宜、更容易理解,但另一方面也更依赖编译器替你把事情做好。由于 TPU 只有单一控制线程,且仅支持 VPU 级别的向量化指令,编译器必须手动对所有内存加载以及 MXU/VPU 工作进行流水化,以避免停顿。GPU 程序员则可以直接启动几十个不同 kernel,让它们运行在完全独立的 SM 上。但另一方面,这些 kernel 也可能因为疯狂冲击 L2 缓存或没能合并内存访问而性能极差;因为运行时太多东西都由硬件控制,导致你很难知道幕后到底发生了什么。因此,TPU 往往能以更少工作更接近峰值 roofline 性能。

历史上,单块 GPU 通常比同级别 TPU 更强大(也更昂贵): 单块 H200 的 FLOPs/s 接近 TPU v5p 的 2 倍,HBM 也多出 1.5 倍。与此同时,Google Cloud 上 H200 的标价大约是每小时 \$10,而 TPU v5p 大约是每小时 \$4。TPU 往往更依赖把多块芯片联网起来,而 GPU 则更强调单卡能力。

TPU 拥有更多高速缓存内存。 TPU 的 VMEM 也远比 GPU 的 SMEM(加上 TMEM)大,而且这种内存可以用来存放权重和激活值,并以极高速度被加载和使用。如果你能持续把模型权重存进 VMEM 或预取进 VMEM,那么在 LLM 推理时 TPU 可能会更快。

小测验 1:GPU 硬件

下面是一些用于检验上述内容的小题。答案都给出了,但最好还是先拿起纸笔自己算一算。

问题 1【CUDA cores】: H100 有多少个 fp32 CUDA core(ALU)?B200 呢?这和 TPU v5p 中独立 ALU 的数量相比如何?

点击这里查看答案。 **答案:** H100 有 132 个 SM,每个 SM 有 4 个子分区,每个子分区含 32 个 fp32 CUDA core,因此总数是 `132 * 4 * 32 = 16896` 个 CUDA core。B200 有 `148` 个 SM,因此总数为 `18944`。TPU v5p 有 2 个 TensorCore(通常通过 Megacore 连接),每个都有一个 VPU,其通道数是 `(8, 128)`,并且每条通道上有 4 个独立 ALU,因此共有 `2 * 4 * 8 * 128 = 8192` 个 ALU。也就是说,它大约只有 H100 一半数量的向量通道,而运行频率却大致相近。

问题 2【向量 FLOPs 计算】:单块 H100 有 132 个 SM,时钟频率为 1.59GHz(boost 时最高 1.98GHz)。假设每个 ALU 每个周期可做 1 个向量操作,那么每秒可以完成多少 vector fp32 FLOPs?开启 boost 后呢?与 matmul FLOPs 相比如何?

点击这里查看答案。 **答案:** `132 * 4 * 32 * 1.59e9 = 26.9TFLOPs/s`。开启 boost 后为 33.5 TFLOPs/s。这只有 [规格表](https://www.nvidia.com/en-us/data-center/h100/) 上报告值的一半,因为严格说来一个周期内我们可以执行一次 FMA(融合乘加),那会算作 2 个 FLOPs,但这在大多数情况下并不实用。H100 可以做到 990 bfloat16 matmul TFLOPs/s,因此忽略 FMA 后,Tensor Core 的 FLOPs/s 大约高 30 倍。

问题 3【GPU matmul 强度】: H100 的峰值 fp16 matmul 强度是多少?B200 呢?fp8 呢?这里所说的强度,是指 matmul FLOPs/s 与内存带宽之比。

点击这里查看答案。 **答案:** 对 H100,我们有峰值 990e12 fp16 FLOPs 和 3.35e12 bytes/s 的带宽。因此临界强度是 `990e12 / 3.35e12 = 295`,与 TPU 的 240 相当接近。B200 则是 `2250e12 / 8e12 = 281`,也非常相似。这意味着,与 TPU 类似,要让一个 matmul 进入计算受限区,我们大约需要 280 左右的 batch size。 对 H100 和 B200 来说,fp8 FLOPs 都恰好是 fp16 的 2 倍,所以峰值强度也翻倍,分别到 590 和 562;不过从某种意义上讲,如果你考虑到权重也很可能以 fp8 加载,那么这个强度也可以认为基本保持不变。

问题 4【Matmul 运行时间】: 根据问题 3 的结果,你预计单块 B200 上执行 fp16[64, 4096] * fp16[4096, 8192] 需要多久?fp16[512, 4096] * fp16[4096, 8192] 呢?

点击这里查看答案。 根据上面的分析,我们知道当 batch size 低于 281 个 token 时会受通信带宽限制。因此第一个完全是带宽受限。我们需要读取或写回 $2BD + 2DF + 2BF$ 字节(`2*64*4096 + 2*4096*8192 + 2*64*8192=69e6`),带宽为 `8e12` bytes/s,因此大约耗时 `69e6 / 8e12 = 8.6us`。实际中我们通常拿不到全部带宽,因此更可能接近 10 到 12us。增大 batch size 后,就完全进入计算受限,因此预计 `T=2*512*4096*8192/2.3e15=15us`。同样,我们也只会拿到总 FLOPs 的一部分,因此实际更可能接近 20us。

问题 5【L1 cache 容量】: H100 的总 L1/SMEM 容量是多少?寄存器内存呢?与 TPU 的 VMEM 容量相比如何?

点击这里查看答案。 **答案:** 每个 SM 有 256kB 的 SMEM 和 256kB 的寄存器内存,所以各自总量大约为 33MB(`132 * 256kB`)。两者合起来约为 66MB。这大约是现代 TPU 120MB VMEM 的一半,不过 TPU 的寄存器内存总共只有 256kB!TPU 的 VMEM 延迟甚至比 SMEM 还低,这也是为什么寄存器内存在 TPU 上没那么关键(溢出到 VMEM 和从 VMEM 回填都很便宜)。

问题 6【估算 B200 时钟频率】: NVIDIA 在这里声称 B200 可提供 80TFLOPs/s 的 vector fp32 计算。已知每个 CUDA core 在一次 FMA(融合乘加)操作中每周期可执行 2 个 FLOPs,请估算其峰值时钟频率。

点击这里查看答案。 **答案:** 我们有 `148 * 4 * 32 = 18944` 个 CUDA core,因此每周期可做 `18944 * 2 = 37888 FLOPs / cycle`。所以 `80e12 / 37888 = 2.1GHz`,这是一个偏高但仍合理的峰值时钟频率。B200 通常采用液冷,因此更高频率也就更说得通。

问题 7【估算 H100 上向量加法运行时间】: 根据上面的数字,计算单块 H100 上把两个 fp32[N] 向量相加理论上需要多久。请同时计算 $T_\text{math}$ 和 $T_\text{comms}$。这个操作的算术强度是多少?如果你能拿到机器,也可以在 PyTorch 或 JAX 中分别对 N = 1024N=1024 * 1024 * 1024 实际运行看看。结果如何?

点击这里查看答案。 **答案:** 首先,相加两个 `fp32[N]` 向量需要执行 N 次 FLOPs,同时要加载 `4 * N * 2` 字节并写回 4 * N 字节,总共是 `3 * 4 * N = 12N` 字节。因此两者比值为 `总 FLOPs / 总字节 = N / 12N = 1 / 12`,这个数可以说低得可怜。 正如前面算过的那样,忽略 FMA 的话,开启 boost 后我们大约可做到 33.5 TFLOPs/s。但这只有在所有 CUDA cores 都被用上时才成立。对于 `N = 1024`,我们最多也只能用到 1024 个 CUDA core,也就是 8 个 SM,因此实际会更慢(如果假设是计算受限,粗略看会慢约 16 倍)。同时,我们还有 3.35e12 bytes/s 的内存带宽。因此峰值硬件强度是 `33.5e12 / 3.35e12 = 10`。值得注意的是,这个强度在最近几代 GPU 上几乎保持不变。H100 上是 33.5 / 3.5,B200 上则是 80 / 8。原因并不清楚,但这个观察很有意思。 因此这个操作显然严重受通信带宽限制。运行时间就是 $$T = \max(T_\text{comms}, T_\text{math}) = \frac{12 \cdot N}{\text{3.35e12}} = \frac{N}{\text{2.8e11}}$$ 对 `N = 65,536`,这大约是 0.23us。实际中我们在 JAX 里看到大约 1.5us,这是合理的,因为这里很明显是延迟受限。对于 `N = 1024 * 1024 * 1024`,roofline 约为 3.84ms,而我们看到 4.1ms,表现不错!

网络互连

网络互连是 GPU 和 TPU 差异最大的地方之一。正如我们已经看到的,TPU 通过 2D 或 3D 环面(torus)连接,每个 TPU 只与自己的相邻 TPU 相连。这意味着两个 TPU 间传递消息时,必须经过中间所有 TPU,也迫使我们只能在整张网格上使用均匀的通信模式。虽然这在某些方面不太方便,但它也意味着每个 TPU 的链路数是常数,因此我们可以在不丢失带宽的前提下扩展到任意大的 TPU “pod”。

相比之下,GPU 使用的是更传统的分层树形交换网络。一组 8 张 GPU(GB200 则最多可达 72 张)构成 节点(node)。这些 GPU 在单跳内通过高带宽互连 NVLink 相连,而多个节点则通过接在每张 GPU 上的 NIC,使用带宽较低的 InfiniBand(IB)或 Ethernet 网络,组合成更大的单元(称为 SU,即 Scalable Unit)。再往上,还可以通过更高层交换机构成任意更大的系统。“节点”这个词本身有歧义,可能表示 NVLink 域,也就是通过 NVLink 完全互连的一组 GPU;也可能表示连接到同一台 CPU 主机的一组 GPU。在 B200 之前,这两者通常是同一个概念,但在 GB200 NVL72 中,我们有一个包含 72 张 GPU 的 NVLink 域,但每台主机仍然只连接 8 张 GPU。本文中“节点”指的是 NVLink 域,但这一用法并非没有争议。

<b>图:</b>典型 H100 网络示意图。一组 8 张 GPU 通过 NVSwitch(也称 NVLink switch)连接为一个节点或 NVLink 域,而这些节点之间再通过交换式 InfiniBand Fabric 互连。H100 在 NVLink 域内每张卡大约有 450GB/s 的出口带宽,而每个节点接入 IB 网络的总出口带宽为 400GB/s。
图:典型 H100 网络示意图。一组 8 张 GPU 通过 NVSwitch(也称 NVLink switch)连接为一个节点或 NVLink 域,而这些节点之间再通过交换式 InfiniBand Fabric 互连。H100 在 NVLink 域内每张卡大约有 450GB/s 的出口带宽,而每个节点接入 IB 网络的总出口带宽为 400GB/s。

节点级别

GPU 节点是一个较小的基本单元,通常由 8 张 GPU 组成(GB200 时最多可达 72 张),这些 GPU 通过全互连、全带宽、低延迟的 NVLink 互连在一起。有人曾把 NVLink 形容为某种“强化版 PCIe”:延迟低、协议开销小,但并不是为可扩展性和容错而设计;而 InfiniBand 更像 Ethernet,是面向更大、更可能丢包的网络设计的。 每个节点内部包含若干高带宽 NVSwitch,用于在本地各 GPU 之间转发数据包。节点内部的真实拓扑在不同代际之间变化很大,包括每个节点的交换机数目都变过;但对于 H100,一个节点有 4 个 NVSwitch,GPU 到它们的连接模式是 5 + 4 + 4 + 5,如下图所示:

<b>图:</b>从 Pascal(P100)开始的节点,即 NVLink 域示意图。自 Volta(V100)起,我们通过一组交换机在节点内实现了全互连。H100 节点有 4 个 NVSwitch,分别通过 25GB/s 链路连接到全部 8 张 GPU。
图:从 Pascal(P100)开始的节点,即 NVLink 域示意图。自 Volta(V100)起,我们通过一组交换机在节点内实现了全互连。H100 节点有 4 个 NVSwitch,分别通过 25GB/s 链路连接到全部 8 张 GPU。

在 Hopper 这一代(NVLink 4.0)中,每条 NVLink 链路都有 25GB/s 的全双工这里的全双工是指每个方向都是 25GB/s,而且两个方向彼此独立。你总共可以在链路上发送 50GB/s 的流量,但每个方向最多 25GB/s。 带宽(B200 则是 50GB/s),因此每张 GPU 向网络的总全双工带宽为 18 * 25=450GB/s。超大规模的 NVSwitch 最多有 64 个 NVLink 端口,这意味着一个 8xH100 节点配上 4 个交换机,理论上可处理 64 * 25e9 * 4=6.4TB/s 的带宽。下面概括了这些数字随 GPU 代际的变化:

NVLink 代际 NVSwitch 代际 GPU 代际 NVLink 带宽(GB/s,全双工) 每张 GPU 的 NVLink 端口数 节点内 GPU 到 GPU 带宽(GB/s,全双工) 节点大小(NVLink 域) 每节点 NVSwitch 数
3.0 2.0 Ampere 25 12 300 8 6
4.0 3.0 Hopper 25 18 450 8 4
5.0 4.0 Blackwell 50 18 900 8/72 2/18

Blackwell(B200)的节点大小是 8 张 GPU。GB200 NVL72 则支持 72 张 GPU 的更大 NVLink 域。我们会同时展示这两种系统的细节。

小测验 2:GPU 节点

下面再来看一些关于网络的问答题。我尤其推荐把这部分认真算一遍,因为它能迫使你真正推导通信模式。

问题 1【H100 节点总带宽】: 一个带有 4 个交换机的 8xH100 节点,总带宽是多少?提示:同时考虑 NVLink 和 NVSwitch 的带宽。

点击这里查看答案。 **答案:** 我们有第 4 代的 4 个 NVSwitch,每个具有 `64 * 25e9=1.6TB/s` 的单向带宽。因此交换机层面一共是 `4 * 1.6e12=6.4e12`。不过注意,每张 GPU 实际最多只能承受 450GB/s 的单向带宽,因此总量上限是 `450e9 * 8 = 3.6TB/s`。因为这个数更小,所以峰值带宽是 3.6TB/s。

问题 2【双向切分带宽】: 双向切分带宽(bisection bandwidth)定义为:把网络均分为两半后,任意一种平分方式下两半之间最小可用带宽是多少。换句话说,把网络拆成两个大小相等的部分后,有多少带宽跨越这条切分边界?你能算出一个 8x H100 节点的双向切分带宽吗?提示:双向切分带宽通常包含两个方向上的流量。

点击这里查看答案。 **答案:** 任意均匀切分都会让每边有 4 张 GPU,而每张 GPU 都可以向另一半出口 `450GB/s`。把双向流量都算上,总共有 `8 * 450GB/s` 字节流跨越切分边界,即 3.6TB/s 的双向切分带宽。这也是 NVIDIA 在例如[这里](https://hc34.hotchips.org/assets/program/conference/day2/Network%20and%20Switches/NVSwitch%20HotChips%202022%20r5.pdf)报告的数值。

问题 3【AllGather 成本】: 对于大小为 B 字节的数组,在一个 8xH100 节点上执行一次(受吞吐限制的)AllGather 需要多久?请把 bf16[D<sub>X</sub>, F]D=4096F=65,536 的情况也算出来。在回答之前,值得先读一读 TPU 集合通信的相关章节。这里先自己推一推,下一节我们会更深入讨论集合通信。

点击这里查看答案。 **答案:** 每张 GPU 的出口带宽是 450GB/s,每张 GPU 持有的数据量是 $B / N$ 字节(这里 `N=8`,即节点大小)。你可以把这个过程想象成:每张 GPU 依次把自己的那部分数据发给其余 $N - 1$ 张 GPU,因此总共有 $(N - 1)$ 轮,每一轮的通信时间都是 $T_\text{comms} = (B / (N * W_\text{unidirectional}))$,于是总时间为 $T_\text{comms} = (N - 1) * B / (N * W_\text{unidirectional})$。这大约等于 $B / (N * W_\text{uni})$,也就是由双向切分带宽决定的 $B / \text{3.6e12}$。 对题目中的数组,我们有 `B=4096 * 65536 * 2=512MB`,因此总时间为 `536e6 * (8 - 1) / 3.6e12 = 1.04ms`。这可能会受到延迟影响,因此实际时间可能更长一些(实际中大约是 1.5ms)。

节点之外

到了节点之外,GPU 网络的拓扑就没有那么标准了。NVIDIA 发布了一个参考 DGX SuperPod 架构,用 InfiniBand 把多于单节点规模的 GPU 连接起来,但客户和数据中心提供商也完全可以根据自己的需要进行定制。例如,Meta 训练 LLaMA-3 时所用的数据中心网络就与这里描述的结构差异很大:它使用的是 Ethernet、三层交换 fabric,以及顶部过订阅的交换机。

下面是一个参考的 1024 张 H100 GPU 系统示意图。底部一排的每个框都表示一个单独的 8xH100 节点,其中包含 8 张 GPU、8 个 400Gbps 的 CX7 NIC(每张 GPU 一个)和 4 个 NVSwitch。

<b>图:</b>参考 1024 张 H100 的 DGX SuperPod 示意图,共 128 个节点(有时是 127),每个节点有 8 张 H100 GPU,并连接到一个 InfiniBand scale-out 网络。每 32 个节点(256 张 GPU)组成一个 “Scalable Unit”,简称 SU。叶交换机和脊交换机组成的 IB 网络能在节点间提供完整的双向切分带宽。
图:参考 1024 张 H100 的 DGX SuperPod 示意图,共 128 个节点(有时是 127),每个节点有 8 张 H100 GPU,并连接到一个 InfiniBand scale-out 网络。每 32 个节点(256 张 GPU)组成一个 “Scalable Unit”,简称 SU。叶交换机和脊交换机组成的 IB 网络能在节点间提供完整的双向切分带宽。

Scalable Unit: 每 32 个节点构成一个 “Scalable Unit”(简称 SU),它们位于同一组 8 个叶层 InfiniBand 交换机之下。这个 SU 包含 256 张 GPU,每个节点有 4 个 NVSwitch,整个 SU 配有 8 个 InfiniBand 叶交换机。图中所示的所有线缆都是 InfiniBand NDR(50GB/s 全双工),交换机也都是 64 端口的 NDR IB 交换机(每端口同样是 50GB/s)。注意,IB 交换机的总带宽是 NVSwitch 的 2 倍(64 个 400 Gbps 端口)。

SuperPod: 整个 SuperPod 再用 16 个顶层 “spine” IB 交换机把 4 个这样的 SU 连在一起,从而形成一个 1024 张 GPU 的系统,其中共有 512 个节点级 NVSwitch、32 个叶层 IB 交换机以及 16 个脊层 IB 交换机,总计 512 + 32 + 16 = 560 个交换机。叶交换机以 32 个节点为一组连接到底层节点,因此每组 256 张 GPU 都对应 8 个叶交换机。所有叶交换机都连接到所有脊交换机。

总带宽有多少? 整个 InfiniBand 网络(也就是 “scale-out network”)的拓扑是一棵 胖树(fat tree),其线缆和交换机配置保证了节点级以上的完整双向切分带宽(这里是 400GB/s)。这意味着,如果把节点分成两半,每个节点都能同时以 400GB/s 的带宽向另一半中的某个节点发送数据。更重要的是,这意味着在 scale-out 网络中,我们的 AllReduce 带宽应当大致保持常数!虽然实际实现不一定真这样做,但你完全可以把它想象成在任意多节点上做一个环形归约,因为从拓扑上总能构造出覆盖所有节点的环。

层级 GPU 数 每单元交换机数量 交换机类型 每单元带宽(TB/s,全双工) GPU 到 GPU 带宽(GB/s,全双工) 胖树带宽(GB/s,全双工)
Node 8 4 NVL 3.6 450 450
Leaf 256 8 IB 12.8 50 400
Spine 1024 16 IB 51.2 50 400

相比之下,TPU v5p 每条链路的出口带宽约为 90GB/s,因此沿 3D 环面的所有轴总出口带宽约为 540GB/s。尽管它不是点对点网络,只能用于受限的、均匀的通信模式,但它仍然提供了更高的 TPU 对 TPU 带宽,而且可以扩展到极大的拓扑规模(至少到 8960 张 TPU)。

GPU 的交换网络在理论上也可以通过增加更多交换机或更多层级来扩展到任意规模,但代价是更高延迟以及昂贵的网络交换设备。

**要点:** 在一个 H100 节点内部,每张 GPU 都拥有完整的 450GB/s 胖树带宽;而一旦跨出节点,节点到节点之间就下降为 400GB/s。这一点对通信原语至关重要。

GB200 NVL72: NVIDIA 最近开始生产新的 GB200 NVL72 GPU 集群,它在单个 NVLink 域内将 72 张 GPU 以完整的 900GB/s GPU 到 GPU 带宽连接起来。然后这些域再通过成比例更高(9 倍)的 IB 胖树带宽连接成更大的 SuperPod。其拓扑如下图所示:

<b>图:</b>展示 576 张 GPU 的 GB200 DGX SuperPod 拓扑图。底层的每个机架都包含 72 张 GB200 GPU。
图:展示 576 张 GPU 的 GB200 DGX SuperPod 拓扑图。底层的每个机架都包含 72 张 GB200 GPU。

只数单个节点的出口带宽(上图中橙色的线),我们得到 4 * 18 * 400 / 8 = 3.6TB/s 的叶层带宽,这恰好是 H100 的 9 倍(节点中的 GPU 数量本来也多了 9 倍)。这意味着节点出口带宽这个关键指标大大提高,导致跨节点集合通信的带宽反而可能 低于 节点内部的通信带宽。 更多讨论见附录 A

节点类型 每节点 GPU 数 单 GPU 出口带宽 节点出口带宽
H100 8 450e9 400e9
B200 8 900e9 400e9
GB200 NVL72 72 900e9 3600e9

**要点:** GB200 NVL72 SuperPod 大幅增加了节点大小和单节点出口带宽,因此显著改变了我们的 roofline 结论。

小测验 3:节点之外

问题 1【胖树拓扑】: 利用上面的 DGX H100 图,计算整个 1024 GPU pod 在节点级别的双向切分带宽。说明每条链路的带宽为什么恰好被设计成可以提供完整的双向切分带宽。提示:请同时计算链路带宽和交换机带宽。

点击这里查看答案。 **答案:** 我们分组件来计算: * 首先,每个节点有 8 条 400Gbps 的 NDR IB 线缆连接到叶交换机,因此每节点到叶层的带宽为 `8 * 400 / 8 = 400 GB/s`。我们有 8 个叶交换机,每个带宽 3.2TB/s(64 条 400Gbps 链路),但其中只有 32 个端口用于从 SU 向交换机注入流量,因此总计 `32 * 400 / 8 = 12.8TB/s`,对于 32 个节点来说也正好是每节点 400GB/s。 * 再到脊层,每个 SU 通过 `8 * 16 * 2` 条 400Gbps 的 NDR IB 线缆连接到脊交换机,因此每个 SU 到脊层的带宽为 `8 * 16 * 2 * 400 / 8 = 12.8 TB/s`。这对每个节点来说仍然是 400GB/s。我们有 16 个脊交换机,每个 3.2TB/s,总计 `16 * 3.2 = 51.2 TB/s`,分给 128 个节点,同样还是每节点 400GB/s。 因此,无论如何对节点做二分,节点之间都将具有每 GPU 400GB/s 的带宽。每一层组件的带宽都恰好满足胖树要求。

问题 2【扩展到更大的 DGX Pod】: 假设我们希望用 2048 张 GPU 而不是 1024 张来训练。要在上面的 DGX 拓扑基础上支持这一点,最简单/最好的改法是什么?如果是 4096 张呢?提示:这没有唯一正确答案,但尽量控制成本。记得考虑链路容量。这份文档可能有帮助。

点击这里查看答案。 **答案:** 一种做法是保留现有 SU 结构不变(每 32 个节点挂在 8 个交换机下),然后继续向上加更多 SU,并增加更多顶层交换机。要扩展到 2048 张 GPU,我们需要 2 倍数量的脊交换机,也就是 8 个 SU 加 32 个脊交换机,才能提供足够带宽。 问题在于,每个叶交换机只有 64 个端口,而在上图中这些端口已经全部用满了。一个简单办法是,把每个脊交换机方向上的连接从 2 条 400Gbps NDR 线缆减为 1 条。这样总带宽不变,但会节省一部分端口。 而扩展到 4096 张 GPU 时,我们最终会遇到端口数耗尽的问题,于是必须再增加一级间接层,也就是再增加一层网络层级。NVIDIA 将这一层称为 “core switches”,并通过 128 个脊交换机和 64 个 core 交换机构建 4096 张 GPU 的集群。你可以自行验证,这样的配置能提供足够带宽。

GPU 上的集合通信如何工作?

GPU 能执行 TPU 上所有同样的集合通信:ReduceScatter、AllGather、AllReduce 和 AllToAll。与 TPU 不同的是,这些操作的具体实现会随着它们发生在节点内(走 NVLink)还是节点外(走 InfiniBand)而改变。这些集合通信由 NVIDIA 的 NVSHMEMNCCL(读作 “nickel”)库实现。NCCL 已经在这里开源。尽管 NCCL 会根据延迟需求和拓扑采用多种不同实现(见细节),从现在开始,我们讨论的是在交换式树形网络上理论上最优的模型。

节点内集合通信

AllGather 或 ReduceScatter: 在节点级别,AllGather 和 ReduceScatter 可以像 TPU 一样沿环执行,并在每一跳都利用完整的 GPU 到 GPU 带宽。你可以任意给 GPU 排一个顺序,然后把数组的一部分沿着这个环传递,每一跳都用满 GPU 到 GPU 带宽。你也可以把它理解为每个 GPU 把自己那块大小为 $\text{bytes} / N$ 的数据发给其余 $N - 1$ 张 GPU,于是总通信量是 $(N - 1) * N * bytes / N$,最终得到同样的答案。 每一跳的代价为 $T_\text{hop} = \text{bytes} / (N * \text{GPU egress bandwidth})$,因此总成本是

$$T_\text{AG or RS comms} = \frac{\text{bytes} \cdot (N - 1)}{N \cdot \text{GPU egress bandwidth}} \rightarrow \frac{\text{bytes}}{\text{GPU egress bandwidth}}$$

你会发现这与 TPU 上完全一样。对于 AllReduce,可以像往常一样把一次 RS 和一次 AG 组合起来,因此代价翻倍。

<b>图:</b>带宽最优的一维环形 AllGather 算法。对于 B 字节的数据,它会把 B / X 字节通过顶层交换机发送 X - 1 次。
图:带宽最优的一维环形 AllGather 算法。对于 B 字节的数据,它会把 B / X 字节通过顶层交换机发送 X - 1 次。

如果你更关心延迟(例如数组很小),可以改用树形归约:先在 2 个节点内做 AllReduce,再扩展到 4 个、8 个,总共只有 $\log(N)$ 跳,而不是 $N - 1$ 跳,不过总成本并不会改变。

**要点:** 在单个节点内对一个大小为 B 字节的数组执行 AllGather 或 ReduceScatter,其成本大约是 $T_\text{comms} = B * (8 - 1) / (8 * W_\text{GPU egress}) \approx B / W_\text{GPU egress}$。对于 H100,理论上约为 $B / \text{450e9}$;对 B200,则是 $B / \text{900e9}$。如果没有启用网络内归约,AllReduce 的成本是它的 2 倍。

随堂小测 1【AllGather 时间】: 在一个 8xH100 节点上,若全双工带宽为 450 GB/s,AllGather(bf16[BX, F]) 需要多久?令 $B=1024$,$F=16,384$。

点击这里查看答案。 **答案:** 总共有 $2 \cdot B \cdot F$ 字节数据,单向带宽是 450e9,因此大致耗时 $T_\text{comms} = (2 \cdot B \cdot F) / \text{450e9}$;更精确地说是 $(2 \cdot B \cdot F \cdot (8 - 1)) / (8 \cdot \text{450e9})$。代入给定数值,得到大约 $(2 \cdot 1024 \cdot 16384) / \text{450e9} = \text{75us}$,更精确则约为 $\text{65us}$。

AllToAll: 节点内的 GPU 彼此全互连,因此 AllToAll 从拓扑上讲就很简单:每张 GPU 直接把数据发给目标 GPU 即可。对于 B 字节,总共有 $B / N$ 字节在每张 GPU 上,而它会向其余 $N - 1$ 个目标节点各发送 $(B / N^2)$ 字节,因此总成本为

$$T_\text{AllToAll comms} = \frac{B \cdot (N - 1)}{W \cdot N^2} \approx \frac{B}{W \cdot N}$$

把它和 TPU 对比,TPU 上成本是 $B / (4W)$。因此,在单个节点内,GPU 理论上可把运行时间提升 2 倍($B / 4W$ 对比 $B / 8W$)。

对于混合专家(MoE)模型,我们经常需要做的是 稀疏或不规则的 AllToAll,即保证输出维度上 N 个 shard 中最多只有 $k$ 个非零,也就是说 $T_\text{AllToAll} \rightarrow K[B, N]$,并且每个轴上最多只有 $k$ 个位置非零。这样成本会按 $k/N$ 缩小,总成本大约是 $\min(k/N, 1) \cdot B / (W \cdot N)$。对于 MoE,我们通常是独立随机地选择这些非零值,因此有一定概率实际少于 $k$ 个非零,故更准确地近似为 $(N-1)/N \cdot \min(k/N, 1) \cdot B / (W \cdot N)$。更精确的成本实际上是 $$(1 - \left(\frac{Z - 1}{Z}\right)^K) \cdot \frac{Z - 1}{Z}$$,也就是 K 次掷骰子后不同结果个数的期望,但它与这里给出的近似非常接近。更多细节见附录。

随堂小测 2【AllToAll 时间】: 在一个 8xH100 节点上,若单向带宽为 450 GB/s,AllToAllX->N(bf16[BX, N]) 需要多久?如果我们知道 8 个条目中只有 4 个会是非零,又会怎样?

点击这里查看答案。 **答案:** 根据上面的结论,在稠密情况下,成本是 $B \cdot (N-1) / (W \cdot N^2)$,也就是 $B / (W \cdot N)$。如果我们知道只有 $\frac{1}{2}$ 的条目不会是 padding,那么我们只需发送 $B \cdot k/N / (W \cdot N) = B / (2 \cdot W \cdot N)$,约为原始成本的一半。

**要点:** 对于单节点 GPU 上大小为 $B$ 字节的数组,AllToAll 的成本约为 $T_\text{comms} = (B \cdot (8 - 1)) / (8^2 \cdot W_\text{GPU egress}) \approx B / (8 \cdot W_\text{GPU egress})$。对一个不规则(top-$k$)AllToAll,这一成本还会进一步降到 $(B \cdot k) / (64 \cdot W_\text{GPU egress})$。

经验测量: 下图展示了一个 8xH100 节点上 AllReduce 带宽的经验测量。Algo BW 是测得的带宽(字节数/运行时间),Bus BW 则按 $2 \cdot W \cdot (8 - 1) / 8$ 计算,可视作链路实际带宽的理论指标。你会看到,我们的确能接近 370GB/s,这虽然比 450GB/s 低,但也算接近了,而且这还是在每设备只有约 10GB 数据量时才达到的。这意味着:尽管这些估算在理论上是正确的,但要真正实现它,往往需要相当大的消息。

<b>图:</b>关闭 SHARP 时,一个 8xH100 节点的 AllReduce 吞吐。蓝色曲线是根据实测结果计算出的链路带宽,即 $2 * \text{bytes} * (N - 1) / (N * \text{runtime})$。可以看到,即便在 10GB 这样的大数组上,我们也并没有特别接近宣称的 450GB/s 带宽。
图:关闭 SHARP 时,一个 8xH100 节点的 AllReduce 吞吐。蓝色曲线是根据实测结果计算出的链路带宽,即 $2 * \text{bytes} * (N - 1) / (N * \text{runtime})$。可以看到,即便在 10GB 这样的大数组上,我们也并没有特别接近宣称的 450GB/s 带宽。

这确实是个问题,因为它实质上让理论分析变得复杂得多。比如,即使是像 LLaMA-3 70B 的 MLP 那样一个相当合理大小的数组(bf16[8192, 28672],如果做 8 路模型切分则是 bf16[8192, 3584] = 58MB),AllReduce 也只能做到大约 150GB/s,而不是峰值 450GB/s。相比之下,TPU 在更小的消息规模上就能达到峰值带宽(见附录 B)。

**要点:** 尽管 NVIDIA 声称 H100 NVLink 的带宽约为 450GB/s,但在实践中很难超过 370 GB/s,因此在用上述估算做判断时,应按这一现实情况进行修正。

网络内归约: 从 Hopper 开始,NVIDIA 的交换机支持 "SHARP"(Scalable Hierarchical Aggregation and Reduction Protocol),允许执行“网络内归约”。这意味着 网络交换机本身 可以执行归约操作,并把结果多路复用或 “MultiCast” 到多个目标 GPU:

<b>图:</b>不使用 SHARP 的 AllReduce 理论成本是使用它时的 2 倍,因为数据必须两次经过每张 GPU。实际中速度提升只有约 30%(来自 NCCL 2.27.5)。
图:不使用 SHARP 的 AllReduce 理论成本是使用它时的 2 倍,因为数据必须两次经过每张 GPU。实际中速度提升只有约 30%(来自 NCCL 2.27.5)。

理论上,这几乎能把 AllReduce 的成本减半,因为每张 GPU 只需要把数据送到顶层交换机,由交换机完成归约并向各目标 GPU 广播结果,而不必让每张 GPU 两次把数据送出,同时还能减少网络延迟。

$$T_\text{SHARP AR comms} = \frac{\text{bytes}}{\text{GPU egress bandwidth}}$$

注意这里是精确等式,不是差一个 $1/N$ 因子,因为每张 GPU 先出口 $B \cdot (N - 1) / N$,然后接收入站的局部分片归约结果(入口 $B/N$),在本地完成剩余归约后,再把 $B/N$ 出口一次,最后再接收入站的完整归约结果 $B \cdot (N - 1) / N$,结果正好总计入口为 $B$ 字节。

不过在实践中,开启 SHARP 后我们看到的带宽提升大约只有 30%,而理论上本该接近 75%。这意味着有效集合通信带宽只是提升到了约 480GB/s,远没有接近 2 倍。

<b>图:</b>在节点内开启和关闭 NVIDIA SHARP 时,AllReduce algo 带宽的经验测量。尽管理论上它本应实现接近 75% 的收益,但峰值吞吐提升实际只有约 30%。
图:在节点内开启和关闭 NVIDIA SHARP 时,AllReduce algo 带宽的经验测量。尽管理论上它本应实现接近 75% 的收益,但峰值吞吐提升实际只有约 30%。

**要点:** 理论上,NVIDIA SHARP(大多数 NVIDIA 交换机都支持)应把大小为 $B$ 字节的 AllReduce 成本从约 $2 * B / W$ 降到 $B / W$。但在实践中我们只看到大约 30% 的带宽提升。由于纯 AllReduce 在 LLM 中本来也不算特别常见,因此这项能力的实际帮助并没有那么大。

跨节点集合通信

跨出节点后,成本会稍微复杂一些。对于树形网络上的归约,你可以把它理解为从下往上做归约:先在节点内,再到叶层,最后到脊层,并在每一层都使用正常的算法。尤其对 AllReduce 来说,这样理解还能帮助我们看到总通信量其实变少了,因为在节点内先做完 AllReduce 后,向叶层出口的只剩下 $B$ 字节,而不是 $B * N$。

这到底有多贵? 第一近似下,由于我们有完整的双向切分带宽,AllGather 或 ReduceScatter 的成本大约等于缓冲区字节数除以节点出口带宽(H100 上是 400GB/s),而与树形归约的许多细节几乎无关。

$$T_\text{AG or RS comms} = \frac{\text{bytes}}{W_\text{node egress}} \underset{H100}{=} \frac{\text{bytes}}{\text{400e9}}$$

其中 $W_\text{node}$ egress 对于上述 H100 网络通常就是 400GB/s(每节点有 8 条 400Gbps 的 IB 链路对外)。理解这一点最简洁的方式,是把它想象成在 整个集群的每一个节点 之间做一个环形归约。由于胖树拓扑的存在,我们总能构造出一个环,让任意相邻节点之间都有 $W_\text{node}$ 的出口带宽,然后按常规方法做归约。节点内归约几乎永远不会成为瓶颈,因为它带宽更高、延迟也更低。更一般地,成本是

$$T_\text{total} = \max(T_\text{comms at node}, T_\text{comms in scale-out network}) = \max\left[\frac{\text{bytes}}{W_\text{GPU egress}}, \frac{\text{bytes}}{W_\text{node egress}}\right]$$
这里给出更精确的推导。 更精确地说,我们实际上是在网络的每一层上都做一个环形归约,而且这些过程大多可以重叠,因此有: $$T_\text{AG or RS comms} = \text{bytes} \cdot max_\text{depth i}\left[\frac{D_i - 1}{D_i \cdot W_\text{link i}}\right]$$ 其中 $D_i$ 是深度 $i$ 处的度数(该层每个节点的子节点数),$W_\text{link i}$ 是每个子节点连接到第 $i$ 层节点时的链路带宽。 利用这一式子,我们可以计算某个给定拓扑下 AllGather/AllReduce 的可用带宽,也就是 $min_\text{depth i}(D_i * W_\text{link i} / (D_i - 1))$。在前面的例子里: * **Node:** $D_\text{node}$ = 8,因为一个节点里有 8 张 GPU,而 $W_\text{link i}$ = 450GB/s。因此 AG 带宽为 `450e9 * 8 / (8 - 1) = 514GB/s`。 * **Leaf:** $D_\text{leaf}$ = 32,因为一个 SU 中有 32 个节点,而 $W_\text{link i}$ = 400GB/s(来自每节点 8 条 400Gbps 的 IB 链路)。因此带宽为 `400e9 * 32 / (32 - 1) = 413GB/s`。 * **Spine:** $D_\text{spine}$ = 4,因为有 4 个 SU,且 $W_\text{link i}$ = 12.8TB/s(来自前文的 `8 * 16 * 2 * 400Gbps` 链路)。因此带宽是 `12.8e12 * 4 / (4 - 1) = 17.1TB/s`。 所以总体的 AG 或 RS 带宽是 `min(514GB/s, 413GB/s, 17.1TB/s) = 413GB/s`,瓶颈位于叶层。也就是说,在实践中 $T_\text{AG or RS comms} = B / \text{413GB/s}$;换言之,即便在最高层,我们也大约拥有 413GB/s 的 AllReduce 带宽。对于启用了 SHARP 的 AllReduce,这个数会略低一些(约 400GB/s),因为那时没有 $(N - 1) / N$ 这个因子。不过,450GB/s 和 400GB/s 已经足够接近,做近似时完全可以直接用它们。

其他集合通信: 如果没有启用 SHARP,AllReduce 的成本仍然是上述的 2 倍。NVIDIA 也销售支持 SHARP 的 IB 交换机,不过并不是所有云厂商都提供。跨节点时,AllToAll 则变化很大,因为它不像 AllReduce 那样天然具有“层次性”。如果我们想从每张 GPU 给每一张其他 GPU 发送数据,就无法像归约操作那样利用节点层面的完整双向切分带宽。也就是说,如果一个 $N$ 路 AllToAll 横跨 $M = N / 8$ 个节点,那么其成本为

$$T_\text{AllToAll comms} = \frac{B \cdot (M - 1)}{M^2 \cdot W_\text{node egress}} \approx \frac{B}{M \cdot W_\text{node egress}}$$

这等效于带宽只有 50GB/s,而不是 400GB/s。我们会从单个 H100 节点内部的 $B / (8 * \text{450e9})$,退化到跨 2 个节点时的 $B / (2 \cdot \text{400e9})$,也就是性能下降 4 倍以上。

下面总结了 1024-GPU DGX H100 SuperPod 的架构:

层级 GPU 数量 度数(子节点数) 交换机带宽(全双工,TB/s) 线缆带宽(全双工,TB/s) 集合通信带宽(GB/s)
Node 8 8 6.4 3.6 450
Leaf(SU) 256 32 25.6 12.8 400
Spine 1024 4 51.2 51.2 400

我们用 “Collective Bandwidth” 来表示 GPU 或节点的有效出口带宽。它也等于 $\text{bisection bandwidth} * 2 / N$。

**要点:** 跨出节点后,对大小为 B 字节的数组执行 AllGather 或 ReduceScatter,其成本大致是 $B / W_\text{node egress}$;在 H100 DGX SuperPod 上,这就是 $B / \text{400e9}$。而 AllReduce 若未启用 SHARP,则成本翻倍。整体拓扑是一棵胖树,旨在确保任意两对节点之间都具有恒定带宽。

当数组本身还沿另一条轴被切分时的归约: 考虑如下归约成本

$$\text{AllReduce}_X(A[I_Y, J]\ \{ U_X \})$$

也就是说,我们正在对一块本身还沿另一条轴 $Y$ 被切分的数组做 AllReduce。对 TPU 而言,由于每条轴上传输的数据量减少为原来的 $1 / Y$,所以该操作的整体成本会相应降低 $1 / Y$。在 GPU 上,成本取决于哪条轴是“内层轴”(节点内还是节点间),以及每个 shard 是否跨越了多个节点。假设 $Y$ 是内层轴,且数组总共有 $\text{bytes}$ 字节,那么总成本会随着 $Y$ 增大而降低,但只有在 $Y$ 跨越多个节点时才真正改善总时间:

$$T_\text{comms at node} = \frac{\text{bytes}}{W_\text{GPU egress}} \cdot \frac{1}{\min(Y, D_\text{node})}$$ $$T_\text{comms in scale-out network} = \frac{\text{bytes}}{W_\text{node egress}} \cdot \frac{D_\text{node}}{\max(D_\text{node}, Y)}$$ $$T_\text{total} = \max(T_\text{comms at node}, T_\text{comms in scale-out network})$$

这里 N 是 GPU 数量,而 $D_\text{node}$ 再次表示每个节点中的 GPU 数(也就是节点的度数)。你可以看到,如果 $Y < D_\text{node}$,我们只是在节点层面获益,但整体运行时间通常不会下降;而如果 $Y > D_\text{node}$,我们则会获得与跨越节点数成比例的加速。

如果要精确描述树形 AllGatherX(AY { UX })(仍假设 Y 是内层轴)的环形归约规则,则有

$$T_\text{AR or RS comms} = \text{bytes} \cdot \max_{\text{depth } i}\left[\frac{D_i - 1}{D_i \cdot \max(Y, S_{i-1}) \cdot W_{\text{link } i}}\right]$$

其中 $S_i$ 是树上第 $i$ 层之下的子节点规模,也就是 M * N * …。这大致是在说:你跨越的 GPU 或节点越多,能利用到的带宽就越大,但这种提升只会体现在该层级之内。

随堂小测 3【沿两条轴切分】: 假设我们要在单个 SU(256 张芯片)上执行 $\text{AllGather}_X(\text{bf16}[D_X, F_Y])$,且 $Y$ 是内层轴。其耗时关于 $D$、$F$ 和 $Y$ 的函数形式是什么?

点击这里查看答案。 **答案:** 我们可以分两种情况讨论:$Y <= 8$ 和 $Y > 8$。当 $Y <= 8$ 时,瓶颈仍在叶交换机,因此答案依旧是 $T_\text{comms} = 2 * D * F * (32 - 1) / (32 * 400e9)$。当 $Y > 8$ 时,根据上面的推导,大致有 $$T_\text{comms} = \frac{2 \cdot D \cdot F \cdot 256}{Y \cdot \text{12.8e12}} = \frac{2DF}{Y \cdot \text{50GB/s}}$$ 对于 `D = 8192`、`F = 32,768`,结果如下图所示:
<b>图:</b>随着内层轴跨越更多节点,一个分片 AllGather 的理论成本。
图:随着内层轴跨越更多节点,一个分片 AllGather 的理论成本。
注意,如果我们恰好做 8 路模型并行,那么节点级归约成本的确会下降 8 倍,但整体成本并不会因此改变,也就是说它“免费”,却并不能改善总带宽。

**要点:** 当我们沿多条轴切分时,外层归约的成本会按内层轴跨越的节点数相应下降。

小测验 4:集合通信

问题 1【SU 上的 AllGather】: 只考虑一个单独 SU,其中有 M 个节点、每节点 N 张 GPU。一次 AllGather 过程中,节点级交换机会精确地接收和发送多少字节?顶层交换机又是多少?

点击这里查看答案。 **答案:** 我们一步一步来看这个归约过程中各组件的流量: 1. 每张 GPU 向交换机发送 $B / MN$ 字节,因此总入口量是 $NB / MN = B / M$。 2. 交换机再把完整的 $B / M$ 字节向上发送到脊交换机。 3. 它再从脊交换机接收 $B * (M - 1) / M$ 字节。 4. 最后它向 N 张 GPU 各发送 $B - B / MN$ 字节,总计 $N * (B - B / MN) = NB - B / M$。 因此总入口量为 $B$,总出口量为 $BN$,所以瓶颈应在出口端,总时间为 $T_\text{AllGather} = BN / W_\text{node} = B / \text{450e9}$。 对于脊交换机,计算其实更简单。它必须接收 $B / M$ 字节共 M 次(总计 $B$ 字节),然后向外发送 $B (M - 1) / M$ 字节共 M 次,总计 $B * (M - 1)$ 字节。由于后者明显更大,因此其成本为 $T_\text{AllGather} = B \cdot (M - 1) / (M \cdot W_\text{node}) = B \cdot (M - 1) / (M \cdot \text{400e9})$。

问题 2【单节点 SHARP AllReduce】: 设一个单节点内有 N 张 GPU。使用 SHARP(网络内归约)做 AllReduce 时,交换机会精确地接收和发送多少字节?

点击这里查看答案。 **答案:** 仍然一步一步来看: 1. 每张 GPU 发送 $B * (N - 1) / N$ 字节,因此交换机入口共为 $N * B * (N - 1) / N = B * (N - 1)$。 2. 交换机累加局部和后,给每张 GPU 发送 $B / N$ 字节,因此总出口量为 $N * B / N = B$。 3. 然后 GPU 在本地对剩余部分做部分归约,再把这些数据发回交换机,因此又有总计 $N * B / N = B$ 字节入口。 4. 最后交换机收齐各 shard 后做 multicast,向 N 个目标各发送 $B * (N - 1) / N$,总计出口 $B * (N - 1) / N * N = B * (N - 1)$。 因此总入口量和总出口量都为 $B * (N - 1) + B = BN$ 字节。这也支持了整体吞吐恰好等于 $B / W_\text{egress}$ 这一结论。

问题 3【跨节点 SHARP AllReduce】: 考虑一个 bf16[D<sub>X</sub>, F<sub>Y</sub>] 数组,沿 X 轴在单个节点的 N 张 GPU 上切分。AllReduce(bf16[D, FY] { UX }) 需要多久?可以假设我们启用了网络内归约。并解释如果跨越多个节点,情况会有什么不同。

点击这里查看答案。 **答案:** 我们可以在前一个问题的答案上稍作修改。基本上,每张 GPU 先出口 $B * (X - 1) / XY$ 字节,再接收 $B / XY$,随后再把这部分回送到交换机,最后再接收 $B * (X - 1) / XY$。因此总入口和出口量都是 $NB / Y$,总时间就是 $T_\text{comms} = NB / (Y * N * W_\text{link}) = N * 2DF / (Y * N * W_\text{link}) = 2 * D * F / (Y * W_\text{link})$,所以总时间确实会随着 $Y$ 增大而下降。 如果扩展到多个节点之外,我们可以做大致相同的归约,但当节点级交换机向外出口时,它必须发送全部 B 字节,而不只是 $B / Y$。这是因为我们必须把各个 shard 保持分离。

问题 4【脊层 AR 成本】: 考虑与上面相同的设置,但令 $Y = 256$(即 AllReduce 发生在脊层)。此时 AllReduce 需要多久?同样,你可以假设启用了网络内归约。

点击这里查看答案。 **答案:** 这时我们可以利用脊层上夸张得近乎离谱的带宽。4 个节点之间总共有 25.6TB/s 带宽,因此 AllReduce 带宽为 6.4TB/s。启用 SHARP 的情况下,它理论上可以低到 `2 * D * F / 6.4e12` 秒。

问题 5【2 路 AllGather 成本】: 精确计算跨 恰好 2 个节点 的 $B$ 字节 AllGather 成本。请给出精确成本而不仅是近似值,并同时考虑节点内和跨节点的成本。

点击这里查看答案。 **答案:** 在节点内,我们有 $T_\text{comms} = B * 7 / (8 * \text{450e9}) = B / \text{514e9}$;而跨节点时,实际上有 $T_\text{comms} = B * (2 - 1) / (2 * \text{400e9}) = B / \text{800e9}$。因此真正的瓶颈其实是节点内归约,而不是叶层!这也是诸如 DeepSeek v3 之类工作会采用 2 路数据并行的动机之一。

GPU 上 LLM 扩展的 Roofline

现在我们来看本章的核心目标:理解在 GPU 上扩展 LLM 的 roofline。这一部分可以与 TPU 训练章节中的对应部分互为补充。和那一章一样,我们的目标是比较不同并行策略下总的 $T_\text{math}$ 与 $T_\text{comms}$,并理解在什么条件下 $T_\text{comms} > T_\text{math}$。同样地,我们这里只考虑 MLP 模块,其操作为

$$\text{MLP}(x) \equiv x[B, D] *_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D]$$

其中 $B$ 是以 token 计的全局 batch size(即 $B = \text{batch size} \cdot \text{sequence length}$)。

我们先再次给出前面那张汇总表,其中列出了 GPU 级和节点级的有效带宽:

节点类型 每节点 GPU 数 单 GPU 出口带宽 节点出口带宽
H100 8 450e9 400e9
B200 8 900e9 400e9
GB200 NVL72 72 900e9 3600e9

注意: GPU 出口带宽和节点出口带宽都会决定 LLM 的 roofline。我们接下来用 $W_\text{collective}$ 来表示这两者之一,具体取决于当前操作发生在节点内还是节点外。

下面我们像分析 TPU 那样,依次分析 数据并行、张量并行、流水线并行、专家并行 及其组合的计算/通信 roofline。在本节后续的具体计算中,我们将聚焦于 H100。GB200-NVL72 的总体结论是相同的,只是由于节点出口带宽更大,有时真正的瓶颈会落在节点内而不是 scale-out 网络。

数据并行

正如前面提到过的,DP 和 ZeRO 切分在反向传播中都需要做一次权重 AllReduce,或者一次 ReduceScatter 加一次 AllGather。由于这两者成本相同,因此如果想让纯数据并行或 FSDP 在不开启网络内归约的情况下保持计算受限,那么对于大小为 X 的某条轴、每一层的反向传播,有:

$$T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot BDF}{X \cdot C}$$ $$T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot DF}{W_\text{collective}}$$

因此,要满足 $T_\text{math} > T_\text{comms}$,我们需要 $B / (XC) > 1 / W_\text{collective}$,即

$$\frac{B}{X} > \frac{C}{W_\text{collective}}$$

这里的 $W_\text{collective}$ 取决于切分是在节点内还是跨节点。因此:

这比 TPU 高得多;在 TPU 上,如果同时使用三条轴,这个数字大约只有 850。比如,LLaMA-3 在 16000 张 H100 上训练时,每张 GPU 若想进入计算受限区,至少需要 40M token 的 batch size(而他们实际使用的是 16M)。又比如,DeepSeek v3 在 2048 张 H800(其带宽只有 300GB/s,而不是 H100 的 450GB/s)上训练,则每张 GPU 需要达到 $\text{990e12} / \text{300e9} = 3300$ 个 token,也就是总共约 6.7M(而实际上他们用的是 4M)。

如果启用网络内归约,并且采用纯数据并行,理论上我们会拥有 2 倍 AllReduce 带宽,因此上述两个阈值都会减半。不过在实践中,收益更接近 30%,它更多只是用来弥补我们通常无法真正达到标称带宽这一事实。而且纯数据并行很少单独使用,所以这在实践中几乎不重要。

MoE 模型: 对于混合专家(MoE)模型,若我们有 E 个专家、每个 token 激活 k 个专家,则有

$$T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot k \cdot BDF}{X \cdot C}$$ $$T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot EDF}{W_\text{collective}}$$

这会把每张 GPU 所需的 token batch size 再放大一个 $E/k$ 因子,即

$$\frac{B}{X} > \frac{E}{k} \frac{C}{W_\text{collective}}$$

例如,对于新版 OpenAI OSS 模型,若 $k=4$、$E=128$,则跨节点时这个阈值会变成 32 * 2475 = 79,200,可以说高得离谱。

当 X 很小时会怎样? 如果我们只做例如 2 节点数据并行,那么由于存在 $(X - 1) / X$ 的缩放,我们有

$$T_\text{math} = \frac{2 \cdot 2 \cdot 2 \cdot BDF}{N * C}$$ $$T_\text{comms} = \frac{2 \cdot 2 \cdot 2 \cdot DF \cdot (X-1)}{X \cdot W_\text{collective}}$$

这里 X 是节点数,而 $N = 8 \cdot X$。于是对稠密模型来说,我们只需要 $B / N > \alpha \cdot (X - 1) / X$,例如在 2 节点情况下,$B / N > \text{1237}$,也就是前面阈值的一半。正因如此,你会相当频繁地看到 2 路数据并行。

**要点:** 在 H100 或 B200 上,数据并行和 ZeRO 切分若想保持计算受限,每张 GPU 大约需要 2500 个 token 的 batch size,前提是计算与通信完美重叠且 FLOPs 利用率理想。对 MoE 模型而言,这一数字还会再乘上 $E / k$,即总参数数与被激活参数数之比。若数据并行规模较小,则临界 batch size 会下降。

张量并行

张量并行要求我们对激活做一次 AllGather 和一次 ReduceScatter,并把它们与 MLP 的 FLOPs 重叠起来。也就是说,在前向传播中有

$$T_\text{math} = \frac{2\cdot 2 \cdot BDF}{Y \cdot C}$$ $$T_\text{comms} = \frac{2\cdot 2 \cdot BD}{W_\text{collective}}$$

要保持计算受限,就得到规则

$$Y < \frac{F \cdot W_\text{collective}}{C}$$

在节点内,这大约等于 $F / 2200$;跨节点则是 $F / 2475$。对于像 LLaMA-3 那样 F=\text{28000} 的模型,这大约对应 11 路 TP(向下取整,差不多就是 8 路,也正好是一整个节点)。同样地,当我们恰好跨 2 个节点时,会额外得到 2 倍带宽,因此通常可以做到 16 路张量并行($F > 2475 \cdot (Y - 8)$),理论上甚至可以做到大约 19 路模型并行。

**要点:** 若前馈层维度为 F,那么大小为 Y 的张量并行轴在 $Y > F / 2475$ 时会进入通信受限区,因此通常只能局限在单个 NVLink 域内,也就是单节点内,最多扩展到 2 个节点。

专家并行

正如前文所述,混合专家(MoE)模型会让模型权重数量增加到原来的 E 倍,但 FLOPs 只增加到 k 倍,因此数据并行会变得显著更难。一个缓解方法,是沿专家维度切分权重,也就是令 Win[EZ, D, F]。为了执行 MLP 块,我们需要额外引入两次 AllToAll,把激活发送到对应的专家。

正如前面分析过的,如果这个 AllToAllZ->k([B, D, k]) 跨越多个节点,那么它的成本大约为 $T_\text{AllToAll} = 2 \cdot B \cdot D \cdot (Z-8)/Z \min(8 * k / Z, 1)$,因此对于纯专家并行,我们需要

$$T_\text{math} = \frac{4 \cdot B \cdot k \cdot D \cdot F}{Z \cdot C}$$ $$T_\text{comms} = \frac{4 \cdot B \cdot D \cdot (Z-8)}{W \cdot Z} \cdot \min\left(\frac{8 \cdot k}{Z}, 1\right)$$

于是我们有两种可能:要么 $K > Z/8$ 且 $F > \alpha \cdot (Z - 8)/k$;要么 $Z \gg K$ 且 $F > 8 \cdot \alpha$,其中 $\alpha = C/W$。这意味着专家并行有两个可行区域:一种是专家并行规模较小(大致跨 2 个节点),且 F 也较小;另一种则是 F 足够大,这时可以做大规模专家并行,甚至任意扩展到 E 路。

实践中两种情况都能看到:要么是少量专家并行(比如 DeepSeek v3,其 F 很小,因此跨节点专家并行规模较小且受限);要么是 F 很大,这时就可以在 TP 的同时执行大量跨节点 EP。

**要点:** 如果 $F < 8 * C / W_\text{node}$,那么专家并行最多只能跨 1 到 2 个节点,其成本与 TP 相近但略低;而如果 $F > 8 * C / W_\text{node}$,我们就可以以相对较低的代价执行大量专家并行(最多到 $E$ 个节点)。

流水线并行

流水线并行把不同层分散到不同节点上,通信成本极低,因为我们每隔几层只需传递一小批 microbatch 的激活值。历史上,流水线并行受 “pipeline bubbles” 困扰,但随着新的零气泡流水线方案出现,这个问题通常已可以解决。

流水线的总体通信成本非常小:设有 $N_\text{MB}$ 个 microbatch、$N_\text{stages}$ 个阶段,则有 $T_\text{comms per hop} = 2 \cdot B \cdot D / (W \cdot N_\text{MB})$,总共需要经历 $N_\text{MB} + N_\text{stages} - 2$ 次 hop,因此大致有

$$T_\text{total PP comms} = \frac{2BD}{W \cdot N_\text{MB}} \cdot (N_\text{MB} + N_\text{stages} - 2)$$ $$T_\text{per-layer comms} \approx 1.5 \cdot \frac{2BD}{W \cdot N_\text{layers}}$$

因为这里还要除以 $N_\text{layers}$,所以它远小于其他任何通信成本。换句话说,从通信角度看,流水线几乎是“免费的”。那为什么我们不总是用流水线并行?主要有几个原因:

(1) 代码复杂度: 流水线并行不像其他方法那样容易融入自动并行框架(例如 XLA 的 GSPMD)。因为它引入了 microbatch 来隐藏 pipeline bubble,会改变程序结构;而定制的零气泡流水线调度又进一步加剧了这一点,因为它要求前向和反向复杂交错。

(2) 流水线会让数据并行和 FSDP 变得困难: 也许最主要的原因是,流水线与 FSDP、数据并行的配合很差。尤其是 ZeRO-3,因为它要求我们在每个 microbatch 上都 AllGather 权重;但当每个 microbatch 只有 $B / N_\text{microbatches}$ 个 token 时,这样的 AllGather 成本根本无法摊薄。此外,在反向传播时,在最后一个 microbatch 通过某个 stage 之前,我们都无法对该 stage 的梯度做 AllReduce 或 ReduceScatter,这意味着会出现大量无法重叠的通信时间。

<b>图:</b>一个 2 stage、2 microbatch 的流水线示例。F 表示某个 stage 的前向,B 表示某个 stage 的反向(成本是前向的 2 倍),G 表示数据并行 AllReduce,这部分时间可能明显长于单个 microbatch。
图:一个 2 stage、2 microbatch 的流水线示例。F 表示某个 stage 的前向,B 表示某个 stage 的反向(成本是前向的 2 倍),G 表示数据并行 AllReduce,这部分时间可能明显长于单个 microbatch。

(3) 流水线气泡与步间不平衡: 如上面那个(不太好的)流水线调度所示,朴素流水线很容易产生显著的气泡(也就是计算资源空转)。在上图中,第 2 个 stage 在 step 0 空闲,第 1 个 stage 在 step 2 到 3 之间空闲,而第 2 个 stage 在最后一步再次空闲。尽管我们可以通过精心调度来部分避免这些问题,但通常仍会留下一些气泡。而且我们还必须在关键路径上把激活从一个 stage 传到下一个 stage,这也会增加开销:

<b>图:</b>一个展示传输成本(红色)的流水线示例。这会导致不同 stage 相对错位,并增加流水线气泡开销。
图:一个展示传输成本(红色)的流水线示例。这会导致不同 stage 相对错位,并增加流水线气泡开销。

这些问题都有对应的解决办法,但它们往往实现复杂、维护困难;尽管如此,相比其他方法,流水线依然是一种通信成本极低的技术。

关于延迟的一个提醒: 正如前面提到的,GPU 即便在消息相当大时,也很难达到满额 AllReduce 带宽。这意味着,哪怕从理论上看我们能够把专家并行的 AllToAll 扩展到多个节点,实践中也可能连总带宽的 50% 都达不到。因此,我们通常仍会尽量把 TP 或 EP 控制在较少的节点内,以降低延迟开销。

示例

DeepSeek 是怎么做的? 作为参考,DeepSeek V3 使用 2048 张 H800 GPU 训练,其并行策略是:

它们的稳态 batch size 是 4096 * 15360 = 62,914,560 个 token,也就是每张 GPU 约 3 万个 token。可以看到,这已经相当大了,但它们的模型也非常稀疏(k=8, E=256),因此确实需要相当大的 batch size。64 路 EP 和 16 路 PP 叠加后,总模型并行度达到 1024 路,因此 AllReduce 发生在脊层;而且由于数据并行只有 2 路,实际上还能享受 $2 / (2 - 1) = 2$ 倍的有效带宽。这也有助于降低最终与流水线最后几个阶段重叠的那次数据并行 AllReduce 成本。

LLaMA-3 是怎么做的? LLaMA-3 用 1.6 万张 GPU 训练,batch size 为 16M token,也就是每张 GPU 约 1k token。它们采用:

这还是一个稠密模型,因此总体上这些结论都比较直接。16 路 PP 会把数据并行 AllReduce 的成本降低 16 倍,从而显著降低临界 batch size。

GPU 上 LLM 扩展速览

让我们退一步,总结一下目前学到的内容:

从更高层面看,这为我们在 GPU 上切分大模型提供了一个配方:

小测验 5:LLM Roofline

问题 1【B200 的 roofline】: 一套 B200 DGX SuperPod(不是 GB200 NVL72)在节点内拥有 2 倍带宽(900GB/s 出口),但 scale-out 网络带宽与 H100 相同,仍是 400GB/s(来源)。而总 FLOPs 如前文所列。模型并行与数据并行的 roofline 会如何变化?

点击这里查看答案。 **答案:** bfloat16 FLOPs/s 从 990 提升到 2250 TFLOPs,增幅为 2.25 倍。由于节点内带宽也翻倍,因此节点内的 roofline 基本保持不变。以 TP 为例,临界强度变为 `2250e12 / 900e9 = 2500`,因此约束是 $Y < F / 2500$,只比 H100 稍高一点(而在节点大小不变的情况下,这其实帮助并不大)。 但一旦跨出节点,额外带宽没有同步增长,反而让我们更难保持计算受限!例如对于数据并行,临界 batch size 将增加到 `2250e12 / 400e9 = 5625`,因为 GPU 的 FLOPs 大幅增长了,而可用带宽却不变。 GB200 SuperPod 通过使用 72-GPU 节点并增加出口带宽,改变了这一点([来源](https://docs.nvidia.com/dgx-superpod/reference-architecture-scalable-infrastructure-gb200/latest/network-fabrics.html#compute-fabric-576))。

问题 2【如何切分 LLaMA-3 70B】: 考虑使用 bfloat16 训练、fp32 Adam 优化器状态的 LLaMA-3 70B。

  1. 如果只是为了存下权重和优化器状态,最少需要多少张 H100?
  2. 假设我们想在 4096 张 H100 GPU 上训练 15T token。若达到 45% MFU(Model FLOPs Utilization),训练需要多久?
  3. LLaMA-3 70B 的 F = 28,672,训练 batch size 约 4M token。若想在不进入通信受限的前提下,最多能做多大的模型并行?在此基础上配合纯 DP,能否在 4k 张卡上保持计算受限?ZeRO-3 呢?如果再加上 8 路流水线并行呢?注意:请同时考虑通信成本和 GPU 显存。
点击这里查看答案。 1. 权重需要 2 字节,优化器状态需要 8 字节,因此总共至少需要 700GB。每张卡有 80GB DRAM,所以至少需要 9 张 GPU,也就是向上取整至少 2 个 8xH100 节点。这当然训练起来会极慢,而且还放不下梯度检查点,但这是一个下界。 2. 总共需要 `6 * 70e9 * 15e12 = 6.3e24 bf16 FLOPs`。每张 GPU 可提供 `990e12` FLOPs,若 MFU 为 45%,则整体可达到 1.8e18 FLOPs/s。因此总训练时间约为 3.5e6 秒,即 40 天。 3. 在节点内,我们有 450GB/s 带宽,因此上限大约是 `F / 1995 = 28672 / 1995 = 14.372`。由于这不足以跨越 2 个节点,因此实际上最多也就是 8 路模型并行。 1. 这意味着需要做 512 路 DP。先看显存是否够:模型只切了 8 份,因此 `700GB / 8 = 87.5GB / GPU`,放不下,所以不行。 2. 如果使用 8 路 TP 加 ZeRO-3,那么我们会做 512 路 ZeRO-3。显存不会有问题,因为所有东西都切得很激进。但每张 GPU 的 batch size 只有 `4e6 / 4096 = 976`。这已经低于纯 DP 的阈值,而 ZeRO-3 还要搬运权重,因此阈值实际上还会再乘 2。所以也不行。 3. 如果再加上 8 路流水线,那么每个模型并行 shard 将跨 8 个节点。正如前面看到的,这会把叶层 AllGathers 的成本降低 8 倍,因此整体 AllReduce/AllGather 带宽从 400GB/s 提升到 `8 * 400GB/s = 3200GB/s`。这时 roofline 变为 `990e12 / 3200e9 = 309`,所以就没问题了!当然,前提是你能把流水线高效实现出来。

问题 3【Megatron-LM 超参数】: 考虑 Megatron-LM 仓库中的这张图,它展示了其很高的 MFU。

assets/gpu/megatron-hparams.png

注意,它们的 sequence length 到处都是 4096。对于 16B、70B 和 314B 模型,每张 GPU 的 token batch size 分别是多少?假设数据并行是最外层轴,且归约用的是 bfloat16,判断每种配置理论上是计算受限还是通信受限,以及是否存在更优配置?

点击这里查看答案。 **答案:** 先算每张 GPU 的 batch size。 * **16B**:`192 * 4096 / 192 = 4096` token/GPU * **70B**:`384 * 4096 / 768 = 2048` token/GPU * **314B**:`1536 * 4096 / 3072 = 2048` token/GPU 这意味着,除了第一个之外,后两个都在每卡 2k token 左右,明显接近我们前面为 FSDP 计算出的临界阈值。我们给出的阈值是每 GPU 2472 token,对应脊层归约,正好适用于这里。不过,对于 70B 和 314B,由于它们分别具有 16 路和 64 路模型并行(PP + TP),脊层上的有效吞吐又会分别提升 2 倍和 8 倍,因此它们应分别在大约 1k 和 300 token/step 左右就仍然处于计算受限区。

致谢与进一步阅读

本章在很大程度上受益于许多 GPU 专家的帮助,包括:

关于 GPU,有大量优秀读物。以下是我个人尤其喜欢的一些:

附录 A:GB200 会如何改变这些结论?

Blackwell 引入了大量重要的网络变化,包括 NVLink 5,其总 NVLink 带宽翻倍到 900GB/s。B200 仍然像 H100 一样是 8-GPU 节点,但 GB200 系统(把 B200 GPU 与 Grace CPU 组合在一起)则引入了更大的 NVLink 域(NVL72 中为 72 张 GPU,理论上最高可到 576)。更大的 NVLink 域也等效地提高了节点出口带宽,从而降低了节点以上层级的集合通信成本。

<b>图:</b>展示一个 GB200 NVL72 单元如何由 18 个交换机和 72 张 GPU 构成。
图:展示一个 GB200 NVL72 单元如何由 18 个交换机和 72 张 GPU 构成。

在节点内部,这种带宽提升(从 450GB/s 到 900GB/s)意义并不算太大,因为每张 GPU 的总 FLOPs/s 也同时翻倍了。我们的 roofline 大体保持不变,不过由于 NVLink 带宽更高,专家并行会变得更容易。

在节点之外,变化就大得多。下面是一个来自这里的 SuperPod 图。

<b>图:</b>展示一个拥有 576 张 GPU 的 GB200 DGX SuperPod。
图:展示一个拥有 576 张 GPU 的 GB200 DGX SuperPod。

正如你所见,单节点出口带宽提高到了 4 * 18 * 400 / 8 = 3.6TB/s,而 H100 上只有 400GB/s。由于每芯片 FLOPs 也只是大约翻倍,因此这会把有效的跨节点 roofline 改善约 4 倍。此时,我们甚至可能开始担心真正的瓶颈会落在节点层而不是 scale-out 层。

Grace Hopper: NVIDIA 还销售 GH200 和 GB200 系统,它们把若干 GPU 与 Grace CPU 配对在一起。例如,一个 GH200 包含 1 个 H200 和 1 个 Grace CPU,而一个 GB200 系统则包含 2 个 B200 和 1 个 Grace CPU。它的一个优点是,CPU 与 GPU 之间通过一条全带宽 NVLink 连接(称为 NVLink C2C),因此 CPU 到 GPU 带宽非常高,很适合把参数卸载到主机内存。换句话说,对于任意一张 GPU 来说,访问主机内存的带宽与访问另一张 GPU 的 HBM 是一样的。

附录 B:更多网络细节

下面是一张 NVLink 4 交换机的示意图。它总共有 64 个 NVLink4 端口(每个端口使用 2 条物理 lane),以及一个负责跨 lane 交换的大型 crossbar。相比之下,TPU 使用的是带镜面的光交换机,可进行动态重构。

<b>图:</b>单个 NVLink4 交换机的更底层结构示意图。
图:单个 NVLink4 交换机的更底层结构示意图。

在每一层,我们都可能受限于链路带宽,也可能受限于总交换机带宽。

换算到每张 GPU,这意味着节点内 GPU 到 GPU 带宽为 450GB/s,SU 层为 50GB/s,脊层为 25 GB/s。

GPU 经验 AllReduce 带宽:

<b>图:</b>8xH100 集群上的 AllReduce 带宽(节点内,SHARP 关闭)。
图:8xH100 集群上的 AllReduce 带宽(节点内,SHARP 关闭)。

TPU v5p 带宽(单轴):

<b>图:</b>TPU v5p 4x4x4 集群上沿单轴的 AllReduce 带宽。
图:TPU v5p 4x4x4 集群上沿单轴的 AllReduce 带宽。

下面是 AllGather 带宽:

<b>图:</b>8xH100 集群上的 AllGather 带宽(节点内)。
图:8xH100 集群上的 AllGather 带宽(节点内)。
<b>图:</b>TPU v5e 8x16 集群上沿单轴的 AllGather 带宽。
图:TPU v5e 8x16 集群上沿单轴的 AllGather 带宽。

更多关于 AllToAll 成本的说明:

这里我们可以比较近似式 $\min(K / Z) * (Z - 1) / Z$ 与更精确的值 $(1 - ((Z - 1) / Z) ** K) * (Z - 1) / Z$。除了 Z 较小时,两者都非常接近。

<b>图:</b>随着 shard 数量增加,不规则 AllToAll 的近似成本与真实成本对比。
图:随着 shard 数量增加,不规则 AllToAll 的近似成本与真实成本对比。