到底需要多少算力才能部署一个模型,这是一个非常常见的问题。我们就从训练和推理两个场景,分析一下如何估计模型所需要的显存。
训练显存大致分为以下四部分:
- 模型权重:取决于存储的精度,常见的 BF16 和 FP16 占用大小为 2B
- 梯度:反向传播计算的梯度,和权重一样常见情况下占用 2B
- 优化器状态:常见的 Adam 会为每个参数都保存它的 Momentum、Variance 和 Master weights,精度为 FP32 所以总计 12B
- 中间激活值:简单来说就是为了计算反向传播的梯度,需要把前向计算的中间值存储起来,具体计算见下文。
因此使用 AdamW 优化器 + 混合精度训练的经验公式为:
VRAMtrain≈20×N(Bytes)
以 70B 的模型为例,训练显存需要 $\approx 70 \times 10^9 \times 20 \text{Bytes} = 1400 \text{GB}$ ,80GB 的 A100 需要至少 18 张并行训练。
推理只需要存模型的权重(不考虑 KVCache 的情况):
VRAMinfer≈2×N(Bytes)如果是 4-bit 量化:
VRAMinfer≈0.5×N(Bytes)考虑 KVCache 加速推理的情况下,如果精度为 FP16,那么额外需要:
VRAMKVCache≈b×s×l×h×2(Bytes)
在长序列推理中, KVCache 占据显存非常大。示例:LLaMA-7B (hidden_size=4096, layers=32),batch=1, seq_len=32k → KV Cache ≈ 32k × 32 × 2 × 4096 × 2 ≈ 8 GB。 seq_len=128k 的情况下仅仅 KV Cache 会到 32 GB+。
这里的激活值和激活函数没有啥关系,以一个四个 Linear 的模型结构为例进行说明。其前向传播和损失函数的公式如下所示:
x1x2x3x4l=W1x+b1=W2x1+b2=W3x2+b3=W4x3+b4=(y−x4)2
在该公式中:$x$ 和 $y$ 为数据的特征和标签;$W_1$、$b_1$、$W_2$、$b_2$、$W_3$、$b_3$、$W_4$、$b_4$ 为四个 Linear 层的权重和偏置;$x_1$、x$_2$、$x_3$、$x_4$ 都是计算过程中的中间状态。反向传播过程中要对权重进行更新,也就是求损失相对于 $W_1$、$W_2$、$W_3$、$W_4$ 的偏导,按照链式求导法则得到公式如下:
∂W4∂l∂W3∂l∂W2∂l∂W1∂l=∂x4∂l⋅∂W4∂x4=[−2(y−x4)]⋅x3=∂x4∂l⋅∂x3∂x4⋅∂W3∂x3=[[−2(y−x4)]⋅W4]⋅x2=∂x4∂l⋅∂x3∂x4⋅∂x2∂x3⋅∂W2∂x2=[[−2(y−x4)]⋅W4⋅W3]⋅x1=∂x4∂l⋅∂x3∂x4⋅∂x2∂x3⋅∂x1∂x2⋅∂W1∂x1=[[−2(y−x4)]⋅W4⋅W3⋅W2]⋅x对上面这四个权重矩阵的链式求导公式找一下规律,可以发现对于权重矩阵 $W_i$ 的梯度在计算时主要有两项:
- 第一项是上述公式中使用特别大的中括号扩起来的部分,这部分是第 i+1 层反传回来的值,我们使用符号 $i+1$ 来表示这一项;
- 另一项则是第 $i−1$ 层计算出来的中间值,使用符号 $x_{i−1}$ 来表示;
那么对于 $W_i$ 的梯度计算公式就变为了 $\frac{\partial l}{\partial W_i} = l_{i+1} \cdot x_{i-1}$,这里的 $l_{i+1}$ 是第 $i+1$ 层反传过来的,所以计算第 $i$ 层的梯度时只需要做一次矩阵乘法即可。这里的 $x_{i−1}$ 正是在前向传播时计算出来的中间状态,比较官方的术语为 中间激活值。
这里把 transformer 层分为两部分,一部分是 MHA 层,一部分是 FFN 层。下面分别写一下这两部分的公式。一般的资料中关于 transformer 的公式仅写主要的部分,像dropout、normalize、激活函数都会被省略,但是这里由于需要分析中间激活值的显存,所以会把整个 transformer 的所有操作都体现到公式中,如下。
MHA 层的公式如下:
Qxselfxattn=x⋅WQ,K=x⋅Wk,V=x⋅Wv=Dropout[softmax(dQ⋅KT)]⋅V=LN[Dropout(xself⋅wo)+x]FFN 层的公式如下:
xffnxo=GeLU(xattn⋅Wff1)⋅Wff2=LN[Dropout(xffn)+xattn]总的来说,MHA 层的输入为 $x$,输出为 $x_{attn}$;FFN 层的输入为 $x_{attn}$,输出为 $x_o$;
首先定义几个符号:
- b:表示batch_size;
- s:表示seq_length,为文本长度;
- h:表示hidden_dim,为隐藏层的维度;
- a:表示多头注意力中有多个头;
- ha:表示hidden_dim_per_head,为多头注意力中每个头的隐藏层维度;
另外,在实际使用时一般都有 ha∗a=h 成立。
MHA 层需要保存的激活值,以及每个激活值的大小:
Q=x⋅WQK=x⋅WkV=x⋅WvQ⋅KTsoftmax(dQTK)Dropout[softmax(dQ⋅KT)]xself=Dropout[softmax(dQ⋅KT)]⋅Vxself⋅WoDropout(xself⋅wo)xattn=LN[Dropout(xself⋅wo)+x]:维度为 [b,a,s,ha]=[b,s,h],:维度为 [b,a,s,ha]=[b,s,h],:维度为 [b,a,s,ha]=[b,s,h],:维度为 [b,a,s,s],:维度为 [b,a,s,s],:维度为 [b,a,s,s],:维度为 [b,a,s,ha]=[b,s,h],:维度为 [b,s,h],:维度为 [b,s,h],:维度为 [b,s,h],大小为 2bsh 字节大小为 2bsh 字节大小为 2bsh 字节大小为 2bas2 字节大小为 2bas2 字节Dropout 层大小为 bas2 字节大小为 2bsh 字节大小为 2bsh 字节Dropout 层大小为 bsh 字节大小为 2bsh 字节FFN 层需要保存的激活值,以及每个激活值的大小:
xattn⋅Wff1GeLU(xattn⋅Wff1)xffn=GeLU(xattn⋅Wff1)⋅Wff2Dropout(xffn)LN[Dropout(xffn)+xattn]:维度为 [b,s,4h],:维度为 [b,s,4h],:维度为 [b,s,h],:维度为 [b,s,h],:维度为 [b,s,h],大小为 8bsh 字节大小为 8bsh 字节大小为 2bsh 字节Dropout 层大小为 bsh 字节大小为 2bsh 字节将 MHA 和 FFN 层全部加起来得到:
2bsh+2bsh+2bsh+2bas2+2bas2+bas2+2bsh+2bsh+bsh+2bsh+8bsh+8bsh+2bsh+bsh+2bsh=34bsh+5bas2如果有 $l$ 层 transformer,那么这 $l$ 层 transformer 总的中间激活值占用的显存为:$l∗(34bsh+5bas^2)$
上面仅分析了多个 transformer 对应的中间激活值消耗的显存的大小。模型中还会有 embedding 层和解码层。其中解码层没有对应的中间激活值,只需要分析一下 embedding 层即可。
embedding 层的功能是将输入的 token ID 转为向量,其输出的矩阵维度为 [batch_size, seq_length, hidden_size],即 [b, s, h],该中间激活值占用的显存为 2bsh。
综上所述,整个模型所有的中间激活值的大小为$l∗(34bsh+5bas^2)+2bsh$。随着模型越来越大,$l$ 是比较大的,所以有时会忽略 $2bsh$ 这一项,直接使用 $l∗(34bsh+5bas^2)$ 来估计模型的中间激活值的大小。