分布式训练技术 - 张量并行

上一篇文章我们学的是 DP 和 DDP,它们的思路是 用显存冗余换吞吐量。每张 GPU都有完整模型,但是只处理不同的数据,它的本质是复制模型 → 并行处理数据 → 最后通过 AllReduce 同步梯度,代价是模型被复制 N 份,占用 N 倍显存。张量并行 Tensor Parallel 的思路正好相反是 用通信换显存。现在的大模型参数量巨大一张卡很可能放不下,所以把模型拆到多卡,每张 GPU 只有部分模型,但是处理完整的数据,最后进行合并。
引言
上一篇提到的 Data Parallel 的核心机制是:
- 把一个 batch 里的样本拆分到多个 DP worker 上。
- 每个 worker 拿到 batch 里的一个 micro-batch,独立做 forward + backward。
- 最后 All-Reduce 同步梯度。
但这存在一个关键前提:必须存在一个 batch dimension 才能把样本切开。它要求输入张量的形状为 [bs, seq_len, dim], DP 切分的是第 0 维,每个 worker 拿到 [bs/dp_size, seq_len, dim]。
而我们在处理我们在处理 LLM 训练任务,通常是 SFT,样本序列是变长的。我们通常采用两种办法:
- Batching + Padding 模式:把一个 batch 内样本的序列都 padding 到最长,这样LLM的输入是 [bs, max_seq_len, hidden_size],然后通过 attention mask 对 PAD TOKEN 进行掩码
- Packing 模式:在 vLLM 中为了提高计算效率,我们去掉了 batch 维度,把所有 batch 的 sequence 都压成了一个长序列,通过一个张量记录每个 sequence 的结束位置,这样就不用对变长序列插入 PAD 浪费算力
对于 Packing 模型,我们去掉了 batch 维度就无法采用 Data Parallel 了,这一章就研究一下 Tensor Parallel 是如何实现的。
核心思想
张量并行的核心思想是将单层内的权重矩阵切分到多张 GPU 上,协同完成矩阵运算,切分的方式有两种:按行切分权重和按列切分权重。
按列切分权重
以一个线性层 $Y = XW$ 为例,如果 $W \in \mathbb{R}^{d \times d}$,可以将 $W$ 按列切分为两半:$W = [W_1, W_2]$,分别放在两张 GPU 上。每张 GPU 计算 $Y_i = XW_i$,得到输出的一半。最后拼接结果:$Y = [Y_1, Y_2]$。
按行切分权重
按行切分权重矩阵,我们需要把输入张量 $X$ 也按列切开:
对权重的梯度我们有:
对输入的梯度我们有:
这里的 $X$ 代表的不是输入的 input_ids,而是上一层 Decode Layer 传来的中间值,所以我们需要对图中的 $X$ 求偏导。
Embedding 层
Embeddings 的难点在于 weight 较大,需要拆分到多个设备上,并实现正确的lookup,下面以4张卡简述其实现步骤:
- 将 wte 较均等分布到多张卡上
- 将 input_ids 复制到所有卡上
- 在每一张卡上input_ids分别lookup 卡上的子wte
- 将所有卡上的值 all-reduce
MLP 层
在 MLP 里面我们采样对 $A$ 进行列切分,对 $B$ 进行行切分,为什么呢?
- 假设我们全部采用行切分,那么 $A$ 和 $B$ 需要两个 AllReduce,通信量很大。
- 假设都采用列切分,我们按照
X → A → Y → B → Z的流程。第一层我们把 $A$ 矩阵切分为 $A_1$ 和 $A_2$,得到 GPU1 上有 $Y_1=X\cdot A_1$,GPU2 上有 $Y_2=X\cdot A_2$,目前还是正常的。但是第二层就有问题了,此时 GPU1 上有 $Y_1$,GPU2 上有 $Y_2$,然后我们把 $B$ 矩阵按照列切分,GPU1 上有 $Z_1=Y_1\cdot B_1$,GPU2 上有 $Z_2=Y_2\cdot B_2$,他们各自少了 $Y_2$ 和 $Y_1$,每个 GPU 只算了一半的贡献。
假如我们先把 $A$ 列切分,把 $B$ 行切分,那么就有:经过第一次按列切分 GPU1 有 $A_1$ 计算得到 $Y_1$, GPU2 有 $A_2$ 计算得到 $Y_2$。然后第二次按行切分,GPU1 有 $B_1$ 和 $Y_1$ 计算得到 $Z$ 的部分贡献,GPU2 有 $B_2$ 和 $Y_2$ 计算得到 $Z$ 的部分贡献,通过一次 AllReduce 将两者加在一起得到了完整的 $Z$。
Attention 层
Self-Attention 的张量并行更简单,因为self-attention天然的是多头注意力机制,可以将每个头的计算分配到不同的 GPU 上。由于有多个头,可以考虑使用 head 的某个因子数(n)作为设备数,每张卡跑 $head//n$ 个头,那么问题就变成了如何拆分 weight 以及同步最终结果。
假设我们有 4 张 GPU,$W_q$、$W_k$、$W_v$ 矩阵可以被拆分为 4 块如下:
然后我们把输入 $x$ 传到每一个 GPU 和拆分的权重矩阵进行矩阵乘法得到部分的输出,最后拼接起来。
通讯量
- MLP、Attention:forward 和 backward 各一次 AllReduce,AllReduce 分为 Reduce-Scatter 和 All-Gather 两个阶段,总通讯量为 $4\Psi$。
- Embedding:forward 部分每个 GPU 只负责 vocab 的一部分,lookup 不需要通信。而 backward 需要对 embedding weight 做梯度聚合是一次 AllReduce,所以通信量为 $2\Psi$。
具体实现
Embedding
在介绍 Embedding 层之前说明一下 TP 里面出现的参数:
tp_size:总 GPU 数量tp_rank:当前所在 GPU 编号
class VocabParallelEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
def forward(self, x: torch.Tensor):
if self.tp_size > 1:
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask.unsqueeze(1) * y
dist.all_reduce(y)
return y
__init__() 中获取了当前所在的 GPU 编号,并且对词表进行了划分。然后在前向计算过程中:
- 先计算掩码,把不在划分范围内的 token id 记为 False。
- 然后对 token id 进行一个 shift 操作,把 $[a_1,a_2,\ldots,a_n]$ 移到 $[0, 1, \ldots, n]$,再应用 mask。
- 得到 $y$ 之后还需要应用一下 mask,因为 mask 掉 token id 变成 0,0 对应的 embedding tensor 也要置为全零。
- 最后通过 AllReduce 传出去。
Lm_Head
Lm_Head 就是 Embedding 的一个逆过程
class ParallelLMHead(VocabParallelEmbedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
):
assert not bias
super().__init__(num_embeddings, embedding_dim)
def forward(self, x: torch.Tensor):
# x: [tokens, dim]
context = get_context()
if context.is_prefill:
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight)
if self.tp_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
这里需要注意一下,在推理框架中我们通过 Continuous Batching 把多条序列被直接拼接成一个扁平的 tensor:
seq1: [t1, t2, t3] →
seq2: [t4, t5] → [t1, t2, t3, t4, t5, t6, t7, t8] shape: [8, dim]
seq3: [t6, t7, t8] →
cu_seqlens_q 记录的就是边界,例如:[0, 3, 5, 8]。对于每条序列,我们只需要最后一个位置的 logits 来预测 next token,所以通过 cu_seqlens_q[1:] - 1 获得每个序列的最后一个 token 的位置。
最后把各个 GPU 计算汇合:
dist.gather将各卡的分片 logits 汇聚到 rank 0torch.cat(..., -1)在最后一维(词表维)拼接,还原完整[batch, vocab_size]- 只有 rank 0 持有完整 logits,其余 rank 返回
None,上层调用方需注意判空
Linear
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
Linear 基类定义了一些初始参数,初始化了权重矩阵,然后 tp_dim 记录了这个 Linear 是列切分还是行切分。
class ColumnParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
Linear 按照列切分之后它的形状完整,但是数值不完整,所以不需要进行 AllReduce。
class RowParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
return y
行切分之后需要进行一次 AllReduce,所有节点的 y 值加起来得到最终结果。
class QKVParallelLinear(ColumnParallelLinear):
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = False,
):
tp_size = dist.get_world_size()
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)
QKV 的计算方法就是把 vocab_size 投影到 3*hidden_size,然后在 dim 维度进行拆分得到 Q/K/V。
MLP
class Qwen3MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
assert hidden_act == "silu"
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
MLP 的第一个 Linear 是列切分,计算后每个 GPU 会得到一部分的结果:
GPU0 得到 $Y_1$,GPU2 得到 $Y_2$。接着经过激活函数,会再次进行 Linear 的线性变换,这次是行切分。我们之前说过行切分需要对 $X$ 进行列切分,正好上一步计算每个 GPU 得到一半的 $Y$, GPU0 拿着 $Y_1$ 和 $W_{b1}$ 计算得到 $Y_1W_{b1}=Z_1$, GPU1 拿着 $Y_2$ 和 $W_{b2}$ 计算得到 $Y_2W_{b2}=Z_2$,最后 AllReduce 之后得到 $Z=Z_1+Z_2$。
import torch.distributed as dist,在程序运行之后会启动 tp_size 个 Python 进程。每个进程绑定一张 GPU,从 train.py 第一行开始执行。假设我们的模型只有 MLP,那么每个进程都会执行 model = Qwen3MLP(...)。每个 GPU 都会输入完整的 x,然后用切分过得权重矩阵对他进行计算,最终得到不完整的 y,然后通过 AllReduce 所有 GPU 都得到了完整的 y,它们就可以继续下去了。AllReduce 模拟
- 定义行切分 Linear
class RowParallelLinear(nn.Module):
"""
权重 W [H_out, H_in] 按输入特征维(行)切分:
rank k 持有 W[:, col_start:col_end],形状 [H_out, shard_size]
输入 X [B, H_in] 同样取对应特征列 X[:, col_start:col_end]
各 rank 计算 partial sum,最终 All-Reduce 求和
"""
def __init__(self, in_features: int, out_features: int,
tp_rank: int, tp_size: int, bias: bool = False):
super().__init__()
assert in_features % tp_size == 0, \
f"in_features={in_features} 必须能被 tp_size={tp_size} 整除"
self.in_features = in_features
self.out_features = out_features
self.tp_rank = tp_rank
self.tp_size = tp_size
self.shard_size = in_features // tp_size # 每个 rank 负责的输入特征列数
self.col_start = tp_rank * self.shard_size
self.col_end = self.col_start + self.shard_size
# 本 rank 只持有权重的一个列分片 [H_out, shard_size]
self.weight = nn.Parameter(torch.empty(out_features, self.shard_size))
nn.init.xavier_uniform_(self.weight)
# bias 只有 rank 0 持有,All-Reduce 后加,避免被重复累加 tp_size 次
self.bias = nn.Parameter(torch.zeros(out_features)) \
if (bias and tp_rank == 0) else None
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, H_in] 完整输入,取本 rank 负责的特征列
x_shard = x[:, self.col_start:self.col_end] # [B, shard_size]
y_partial = x_shard @ self.weight.T # [B, H_out]
return y_partial # 调用方负责 All-Reduce
- 随机初始化输出和权重矩阵
torch.manual_seed(42)
IN_FEATURES = 6 # 输入特征维度(必须能被 TP_SIZE 整除)
OUT_FEATURES = 4
TP_SIZE = 2
BATCH_SIZE = 3
X = torch.arange(BATCH_SIZE * IN_FEATURES, dtype=torch.float32).reshape(BATCH_SIZE, IN_FEATURES)
print("=" * 60)
print(f"模拟开始: TP_SIZE={TP_SIZE}")
print(f"输入 X shape: {list(X.shape)} [B={BATCH_SIZE}, H_in={IN_FEATURES}]")
print(f"输出期望 shape: [B={BATCH_SIZE}, H_out={OUT_FEATURES}]")
print(f"每个 Rank 负责的输入特征列数: {IN_FEATURES // TP_SIZE}")
print("=" * 60)
W_full = torch.arange(OUT_FEATURES * IN_FEATURES, dtype=torch.float32).reshape(OUT_FEATURES, IN_FEATURES) * 0.1
输出为:
============================================================
模拟开始: TP_SIZE=2
输入 X shape: [3, 6] [B=3, H_in=6]
输出期望 shape: [B=3, H_out=4]
每个 Rank 负责的输入特征列数: 3
============================================================
- 模拟 Tensor Parallel 过程
partial_outputs = []
for tp_rank in range(TP_SIZE):
print(f"\n{'#' * 20} 模拟 Rank {tp_rank} 的计算过程 {'#' * 20}")
layer = RowParallelLinear(IN_FEATURES, OUT_FEATURES, tp_rank=tp_rank, tp_size=TP_SIZE)
col_start = layer.col_start
col_end = layer.col_end
# 从完整权重里取本 rank 的列分片,写入层权重
W_shard = W_full[:, col_start:col_end] # [H_out, shard_size]
with torch.no_grad():
layer.weight.copy_(W_shard)
print(f"\nRank {tp_rank} 负责输入特征列范围: [{col_start}, {col_end})")
print(f"权重分片 W_shard shape: {list(W_shard.shape)}")
print(f"权重分片 W_shard:\n{W_shard}")
# 取对应输入列
X_shard = X[:, col_start:col_end]
print(f"\nX_shard (X[:, {col_start}:{col_end}]):\n{X_shard}")
# 局部矩阵乘法 → partial sum
Y_partial = layer(X) # 内部自动切 X_shard
print(f"\nY_partial = X_shard @ W_shard.T:\n{Y_partial}")
partial_outputs.append(Y_partial)
这里注意一下,Linear 里面权重矩阵的形状是相反的,所以我们对
W_shard进行列切分,计算时候转置就是行切分了。
输出为:
#################### 模拟 Rank 0 的计算过程 ####################
Rank 0 负责输入特征列范围: [0, 3)
权重分片 W_shard shape: [4, 3]
权重分片 W_shard:
tensor([[0.0000, 0.1000, 0.2000],
[0.6000, 0.7000, 0.8000],
[1.2000, 1.3000, 1.4000],
[1.8000, 1.9000, 2.0000]])
X_shard (X[:, 0:3]):
tensor([[ 0., 1., 2.],
[ 6., 7., 8.],
[12., 13., 14.]])
Y_partial = X_shard @ W_shard.T:
tensor([[ 0.5000, 2.3000, 4.1000, 5.9000],
[ 2.3000, 14.9000, 27.5000, 40.1000],
[ 4.1000, 27.5000, 50.9000, 74.3000]], grad_fn=<MmBackward0>)
#################### 模拟 Rank 1 的计算过程 ####################
Rank 1 负责输入特征列范围: [3, 6)
权重分片 W_shard shape: [4, 3]
权重分片 W_shard:
tensor([[0.3000, 0.4000, 0.5000],
[0.9000, 1.0000, 1.1000],
[1.5000, 1.6000, 1.7000],
[2.1000, 2.2000, 2.3000]])
X_shard (X[:, 3:6]):
tensor([[ 3., 4., 5.],
[ 9., 10., 11.],
[15., 16., 17.]])
Y_partial = X_shard @ W_shard.T:
tensor([[ 5.0000, 12.2000, 19.4000, 26.6000],
[ 12.2000, 30.2000, 48.2000, 66.2000],
[ 19.4000, 48.2000, 77.0000, 105.8000]], grad_fn=<MmBackward0>)
- AllReduce 聚合
print(f"\n\n{'=' * 25} 模拟 All-Reduce 聚合 {'=' * 25}")
for i, p_out in enumerate(partial_outputs):
print(f"\n来自 Rank {i} 的 partial sum:\n{p_out}")
final_output = torch.stack(partial_outputs).sum(dim=0)
print(f"\n聚合后的最终结果 (sum of all partial outputs):\n{final_output}")
输出为:
========================= 模拟 All-Reduce 聚合 =========================
来自 Rank 0 的 partial sum:
tensor([[ 0.5000, 2.3000, 4.1000, 5.9000],
[ 2.3000, 14.9000, 27.5000, 40.1000],
[ 4.1000, 27.5000, 50.9000, 74.3000]], grad_fn=<MmBackward0>)
来自 Rank 1 的 partial sum:
tensor([[ 5.0000, 12.2000, 19.4000, 26.6000],
[ 12.2000, 30.2000, 48.2000, 66.2000],
[ 19.4000, 48.2000, 77.0000, 105.8000]], grad_fn=<MmBackward0>)
聚合后的最终结果 (sum of all partial outputs):
tensor([[ 5.5000, 14.5000, 23.5000, 32.5000],
[ 14.5000, 45.1000, 75.7000, 106.3000],
[ 23.5000, 75.7000, 127.9000, 180.1000]], grad_fn=<SumBackward1>)
- 验证结果
print(f"\n\n{'=' * 28} 验证结果 {'=' * 28}")
print(f"\n完整权重矩阵 W_full:\n{W_full}")
ref = nn.Linear(IN_FEATURES, OUT_FEATURES, bias=False)
with torch.no_grad():
ref.weight.copy_(W_full)
ref_output = ref(X)
print(f"\n标准 nn.Linear 计算结果:\n{ref_output}")
# 逐样本对比
print(f"\n--- 逐样本对比 ---")
for b in range(BATCH_SIZE):
tp_vec = final_output[b]
ref_vec = ref_output[b]
match = torch.allclose(tp_vec, ref_vec, atol=1e-5)
print(f"样本 {b}: TP={tp_vec.tolist()} REF={ref_vec.tolist()} {'✓' if match else '✗'}")
are_equal = torch.allclose(final_output, ref_output, atol=1e-5)
print(f"\n并行计算结果与标准 nn.Linear 结果是否一致: {are_equal}")
输出为:
============================ 验证结果 ============================
完整权重矩阵 W_full:
tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000],
[0.6000, 0.7000, 0.8000, 0.9000, 1.0000, 1.1000],
[1.2000, 1.3000, 1.4000, 1.5000, 1.6000, 1.7000],
[1.8000, 1.9000, 2.0000, 2.1000, 2.2000, 2.3000]])
标准 nn.Linear 计算结果:
tensor([[ 5.5000, 14.5000, 23.5000, 32.5000],
[ 14.5000, 45.1000, 75.7000, 106.3000],
[ 23.5000, 75.7000, 127.9000, 180.1000]], grad_fn=<MmBackward0>)
--- 逐样本对比 ---
样本 0: TP=[5.5, 14.5, 23.5, 32.5] REF=[5.5, 14.5, 23.5, 32.5] ✓
样本 1: TP=[14.5, 45.099998474121094, 75.69999694824219, 106.29999542236328] REF=[14.5, 45.099998474121094, 75.69999694824219, 106.30000305175781] ✓
样本 2: TP=[23.5, 75.69999694824219, 127.9000015258789, 180.10000610351562] REF=[23.5, 75.69999694824219, 127.9000015258789, 180.10000610351562] ✓
并行计算结果与标准 nn.Linear 结果是否一致: True
