🚀 PyTorch 并行训练指南

一、三种并行方式概述

1. 数据并行 (Data Parallelism)

GPU0: 完整模型 + 数据批次1 梯度1 ─┐ GPU1: 完整模型 + 数据批次2 梯度2 ─┼→ 聚合 更新 GPU2: 完整模型 + 数据批次3 梯度3 ─┘

2. 模型并行 / 张量并行 (Model/Tensor Parallelism)

一个大型线性层 W: GPU0: W[:, 0:n/2] ─┐ ├→ 合并输出 GPU1: W[:, n/2:n] ─┘

3. 流水线并行 (Pipeline Parallelism)

GPU0: 层1-4 GPU1: 层5-8 GPU2: 层9-12 ↑ ↑ ↑ micro-batch流水线执行

对比总结

并行方式 切分对象 通信开销 典型框架支持
数据并行 数据 中(梯度同步) PyTorch DDP, Horovod
张量并行 层内参数 高(每层通信) Megatron-LM
流水线并行 层间 低(仅层边界) GPipe, PipeDream

💡 大规模训练(如GPT、LLaMA)通常三者混合使用,称为 3D并行

训练与推理适用性

并行方式 训练 推理 说明
数据并行 ✅ 主要用途 ⚠️ 不常用 训练时同步梯度;推理无需梯度,通常用多实例替代
张量并行 ✅ 常用 ✅ 常用 单层参数过大时必需,vLLM/TensorRT-LLM 等推理框架主要方式
流水线并行 ✅ 常用 ✅ 支持 模型层数多时使用,训练用micro-batch减少气泡

📌 推理场景总结:大模型推理框架(vLLM、TensorRT-LLM、HF TGI)主要使用张量并行流水线并行,数据并行在推理中通常被简单的多实例部署替代。

二、PyTorch 原生支持

1. 数据并行 - DistributedDataParallel (DDP)

推荐使用
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
model = DDP(model, device_ids=[local_rank])

注意 旧版 DataParallel 已不推荐(单进程多线程,效率低)

2. 张量并行 - torch.distributed.tensor.parallel

PyTorch 2.0+
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel
)

parallelize_module(
    model,
    device_mesh,
    {
        "layer.weight": ColwiseParallel(),
        "output.weight": RowwiseParallel(),
    }
)

3. 流水线并行 - torch.distributed.pipelining

PyTorch 2.0+
from torch.distributed.pipelining import (
    pipeline, SplitPoint, ScheduleGPipe
)

pipe = pipeline(
    model,
    mb_args=(microbatch,),
    split_spec={
        "layer4": SplitPoint.END,
        "layer8": SplitPoint.END,
    }
)
schedule = ScheduleGPipe(pipe, n_microbatches=4)
schedule.step(x)

3D 并行整合 - DeviceMesh

from torch.distributed.device_mesh import init_device_mesh

# 创建 2D mesh: (数据并行, 张量并行)
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

更高层封装工具

工具说明
torch.distributed.fsdp分片数据并行(ZeRO风格)
torchtitan官方大模型训练参考实现
PyTorch Lightning简化多GPU训练配置

三、GPU 分配方法

1. 环境变量(推荐)

# 只使用 GPU 0 和 2
CUDA_VISIBLE_DEVICES=0,2 python train.py

# 程序内看到的是 cuda:0 和 cuda:1(重新编号)

2. 代码中指定设备

# 指定默认设备
torch.cuda.set_device(0)

# 创建张量时指定
tensor = torch.randn(3, 3, device="cuda:1")

# 模型移动到指定GPU
model = model.to("cuda:2")
# 或
model = model.cuda(2)

3. 分布式训练 (DDP)

启动方式:

# torchrun 自动分配 (推荐)
torchrun --nproc_per_node=4 train.py

# 指定使用哪些 GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py

代码中获取分配的 GPU:

import os
import torch.distributed as dist

dist.init_process_group(backend="nccl")

local_rank = int(os.environ["LOCAL_RANK"])  # torchrun 自动设置
torch.cuda.set_device(local_rank)

model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])

4. 多机分布式

# 机器1 (master)
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 \
         --master_addr=192.168.1.1 --master_port=29500 train.py

# 机器2
torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 \
         --master_addr=192.168.1.1 --master_port=29500 train.py

5. DeviceMesh(3D 并行)

from torch.distributed.device_mesh import init_device_mesh

# 8卡: 2路数据并行 × 4路张量并行
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# 获取当前进程在各维度的位置
dp_rank = mesh.get_local_rank("dp")  # 0 或 1
tp_rank = mesh.get_local_rank("tp")  # 0, 1, 2, 或 3

常用查询命令

torch.cuda.device_count()        # 可用GPU数量
torch.cuda.current_device()      # 当前默认GPU
torch.cuda.get_device_name(0)    # GPU型号
torch.cuda.memory_summary(0)     # 显存使用情况

nvidia-smi                       # 查看所有GPU状态

总结

场景方法
单卡训练CUDA_VISIBLE_DEVICES=0.to("cuda:0")
单机多卡 DDPtorchrun --nproc_per_node=N
多机多卡torchrun --nnodes=M --nproc_per_node=N
混合并行DeviceMesh 定义拓扑

四、1024 GPU 3D并行实战案例

场景设定

GPU 显存规格

GPU 型号 显存 BF16 算力 显存带宽
A100 SXM 40GB / 80GB 312 TFLOPS 2.0 TB/s
H100 SXM 80GB 989 TFLOPS 3.35 TB/s
H100 NVL 94GB 989 TFLOPS 3.9 TB/s

LLaMA-70B 显存需求分析

组成部分 计算公式 显存占用
模型参数 (BF16) 70B × 2 bytes 140 GB
梯度 (BF16) 70B × 2 bytes 140 GB
优化器状态 (AdamW FP32) 70B × 4 bytes × 3 (参数+m+v) 840 GB
激活值 (带检查点) 依赖 batch size 和序列长度 ~50-100 GB
总计(单卡无并行) - ~1170 GB

3D 并行后每卡显存

并行策略 切分效果 每卡显存
TP=8 (张量并行) 参数/梯度/优化器 ÷ 8 ~18-25 GB
80GB 显存绰绰有余
PP=8 (流水线并行) 仅存 1/8 的层
FSDP (可选) 参数/梯度/优化器再分片

💡 计算:1170 GB ÷ 64 (TP×PP) ≈ 18 GB/卡 + 激活值 ≈ 20-25 GB/卡

训练数据量参考

模型版本 训练数据量 来源
LLaMA-1 70B 1.4 T tokens Meta 论文 (2023.02)
LLaMA-2 70B 2.0 T tokens Meta 论文 (2023.07)
LLaMA-3 70B 15 T tokens Meta 论文 (2024.04)

训练时间估算

理论计算公式:

训练时间 = 总 FLOPs / (GPU数 × 单卡算力 × MFU效率) 其中: • 总 FLOPs ≈ 6 × 模型参数 × tokens数 (前向2倍 + 反向4倍) • MFU (Model FLOPs Utilization) ≈ 35-45% (考虑通信、IO等开销)

实际估算示例

配置项 数值
模型参数 70B = 7×10¹⁰
训练数据 2T tokens = 2×10¹²
总 FLOPs 6 × 7×10¹⁰ × 2×10¹² = 8.4×10²³
GPU 集群 1024 × H100 (989 TFLOPS)
MFU 效率 40%
有效算力 1024 × 989×10¹² × 0.4 = 4.05×10¹⁷ FLOPS
估算训练时间 8.4×10²³ ÷ 4.05×10¹⁷ ≈ 24 天

不同配置对比

GPU 配置 2T tokens 15T tokens
256 × A100 80GB ~90 天 ~675 天
1024 × A100 80GB ~23 天 ~170 天
1024 × H100 80GB ~7 天 ~54 天
2048 × H100 80GB ~3.5 天 ~27 天

⚠️ 注意:实际训练时间受集群网络、存储IO、故障恢复等因素影响,通常比理论值高 20-50%

⚡ 显存约束 vs 算力约束

大模型训练存在两个不同的瓶颈,需要区分对待:

约束类型 本质问题 优化技术 影响
显存约束 模型能否放入GPU 3D并行、ZeRO、Flash Attention、梯度检查点、混合精度 决定 "能不能训练"
算力约束 训练速度多快 增加GPU数量、使用更快GPU (H100)、提高MFU、优化通信 决定 "训练多快"

🧠 显存优化 → 解决 "能不能"

  • 70B模型原始需要 ~1170GB 显存
  • 3D并行 → 降到 ~18-25GB/卡
  • + ZeRO-2/3 → 降到 ~8-15GB/卡
  • + Flash Attention → 进一步节省激活
  • 结论:显存问题已被技术解决

⚡ 算力优化 → 解决 "快不快"

  • 总FLOPs = 6 × 70B × 2T = 8.4×10²³
  • 这是固定的计算量,无法减少
  • 唯一方法:增加有效算力
  • 有效算力 = GPU数 × 单卡FLOPS × MFU
  • 结论:训练时间由算力决定
训练时间计算公式: 6 × 参数量 × Tokens数 训练时间 = ───────────────────────────── GPU数 × 单卡FLOPS × MFU ↑ 固定的工作量 ↑ 可扩展的算力 关键结论: • 显存不足 → 用并行/ZeRO/Flash技术解决,与训练时间无关 • 训练太慢 → 只能增加GPU数量或使用更快的GPU(如H100替代A100) • MFU效率 → 优化通信、减少bubble、Overlap计算与通信

总结:3D并行+ZeRO+Flash解决的是"能否训练"的问题,训练速度完全取决于集群总算力(GPU数量×单卡算力×利用率)

3D 并行配置方案

并行维度 规模 切分策略 原因
张量并行 (TP) 8 单节点内 8 卡 节点内 NVLink 带宽高,适合频繁通信
流水线并行 (PP) 8 80层 ÷ 8 = 每段10层 跨节点通信少,只在层边界传输
数据并行 (DP) 16 16组独立数据流 线性扩展吞吐量

验证:TP × PP × DP = 8 × 8 × 16 = 1024 GPU

GPU 拓扑结构图

1024 GPU 3D 并行拓扑
数据并行 (DP=16)
DP组 0
64 GPU
DP组 1
64 GPU
DP组 2
64 GPU
···
DP组 15
64 GPU
流水线并行 (PP=8) - 每个DP组内部
Stage 0
层 1-10
8 GPU · 节点0
Stage 1
层 11-20
8 GPU · 节点1
Stage 2
层 21-30
8 GPU · 节点2
Stage 3
层 31-40
8 GPU · 节点3
Stage 4
层 41-50
8 GPU · 节点4
Stage 5
层 51-60
8 GPU · 节点5
Stage 6
层 61-70
8 GPU · 节点6
Stage 7
层 71-80
8 GPU · 节点7
张量并行 (TP=8) - 每个Stage内部(单节点)
节点内 8 GPU(NVLink 互联)
GPU 0
W[:,0:n/8]
GPU 1
W[:,n/8:2n/8]
GPU 2
W[:,2n/8:3n/8]
GPU 3
W[:,3n/8:4n/8]
GPU 4
W[:,4n/8:5n/8]
GPU 5
W[:,5n/8:6n/8]
GPU 6
W[:,6n/8:7n/8]
GPU 7
W[:,7n/8:n]
AllReduce 同步

启动脚本示例

#!/bin/bash
# run_1024gpu.sh - 在 SLURM 集群上启动训练

# SLURM 配置
#SBATCH --job-name=llama70b_train
#SBATCH --nodes=128
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=12

# 网络配置
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
export WORLD_SIZE=1024

# 3D 并行配置
export TP_SIZE=8    # 张量并行度
export PP_SIZE=8    # 流水线并行度
export DP_SIZE=16   # 数据并行度(自动计算:1024/8/8=16)

# 训练超参数
export MICRO_BATCH_SIZE=1
export GLOBAL_BATCH_SIZE=2048  # DP_SIZE × gradient_acc_steps × micro_batch
export GRADIENT_ACC_STEPS=128  # 2048 / 16 / 1 = 128

# 启动训练
srun torchrun \
    --nnodes=128 \
    --nproc_per_node=8 \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    train.py \
    --model llama-70b \
    --tp_size $TP_SIZE \
    --pp_size $PP_SIZE \
    --micro_batch_size $MICRO_BATCH_SIZE \
    --global_batch_size $GLOBAL_BATCH_SIZE

训练代码框架

# train.py - 3D并行训练主程序
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel
)
from torch.distributed.pipelining import pipeline, SplitPoint, ScheduleGPipe
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    # 1. 初始化分布式环境
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    world_rank = int(os.environ["RANK"])
    torch.cuda.set_device(local_rank)

    # 2. 创建 3D 设备网格 (DP, PP, TP)
    #    1024 GPU = 16 × 8 × 8
    mesh_3d = init_device_mesh(
        "cuda",
        (16, 8, 8),  # (DP, PP, TP)
        mesh_dim_names=("dp", "pp", "tp")
    )

    # 获取当前进程在各维度的位置
    dp_rank = mesh_3d.get_local_rank("dp")   # 0-15
    pp_rank = mesh_3d.get_local_rank("pp")   # 0-7 (哪个Stage)
    tp_rank = mesh_3d.get_local_rank("tp")   # 0-7

    # 3. 构建模型(仅当前PP Stage的层)
    layers_per_stage = 10  # 80层 / 8 stages
    start_layer = pp_rank * layers_per_stage
    end_layer = start_layer + layers_per_stage

    model = build_llama_layers(
        start_layer=start_layer,
        end_layer=end_layer,
        hidden_size=8192,
        num_heads=64
    )

    # 4. 应用张量并行(切分 Attention 和 FFN)
    tp_mesh = mesh_3d["tp"]
    parallelize_module(
        model,
        tp_mesh,
        {
            # Attention: Q/K/V 列切分,Output 行切分
            "attention.q_proj": ColwiseParallel(),
            "attention.k_proj": ColwiseParallel(),
            "attention.v_proj": ColwiseParallel(),
            "attention.o_proj": RowwiseParallel(),
            # FFN: gate/up 列切分,down 行切分
            "ffn.gate_proj": ColwiseParallel(),
            "ffn.up_proj": ColwiseParallel(),
            "ffn.down_proj": RowwiseParallel(),
        }
    )

    # 5. 设置流水线并行
    pp_group = mesh_3d.get_group("pp")

    # 6. 数据并行(使用FSDP进一步节省显存)
    dp_mesh = mesh_3d["dp"]
    model = FSDP(model, process_group=dp_mesh.get_group())

    # 7. 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 8. 训练循环
    for step, batch in enumerate(dataloader):
        # 流水线执行(自动处理micro-batch)
        with pipeline_context(pp_group, num_microbatches=80):
            loss = model(batch)
            loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if world_rank == 0 and step % 10 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

if __name__ == "__main__":
    main()

通信与计算分析

并行类型 通信操作 频率 通信组规模
张量并行 AllReduce (每层2次) 每个前向/反向 8 GPU(节点内)
流水线并行 P2P Send/Recv 每个micro-batch 2 GPU(相邻Stage)
数据并行 AllReduce (梯度) 每个优化步 16 GPU(跨节点)
通信优先级(由高到低): 1. 张量并行 → 节点内 NVLink (900 GB/s) 2. 流水线并行 → 相邻节点 (仅激活值) 3. 数据并行 → 全局 AllReduce (可与计算重叠)

GPU 互联带宽详解

1024 GPU 集群中,GPU 间带宽取决于物理连接层级

GPU 互联带宽层级结构
节点内
同一台服务器
NVLink
A100: 600 GB/s
H100: 900 GB/s
张量并行
TP=8
节点间
跨服务器
InfiniBand
HDR: 200 Gb/s (25 GB/s)
NDR: 400 Gb/s (50 GB/s)
流水线/数据并行
PP + DP
优化技术
跨节点加速
GPUDirect RDMA - GPU直接读写远程GPU显存
NVSwitch - 节点内全互联

带宽规格对比

连接类型 技术 单向带宽 双向带宽 延迟
节点内 (A100) NVLink 3.0 (12 links) 300 GB/s 600 GB/s ~1 μs
节点内 (H100) NVLink 4.0 (18 links) 450 GB/s 900 GB/s ~1 μs
节点间 (HDR) InfiniBand HDR × 8 200 GB/s 400 GB/s ~1-2 μs
节点间 (NDR) InfiniBand NDR × 8 400 GB/s 800 GB/s ~1-2 μs
以太网 (对比) 100GbE × 8 100 GB/s 200 GB/s ~10-50 μs

1024 GPU (128节点) 实际带宽配置

层级 范围 互联技术 带宽 用途
节点内 8 GPU NVLink 4.0 + NVSwitch 900 GB/s (H100)
600 GB/s (A100)
张量并行 AllReduce
高频,每层2次
节点间 128 节点 InfiniBand NDR × 8 400 GB/s
400Gb/s × 8 = 3.2Tb/s
流水线 P2P + 数据并行
每micro-batch / 每优化步
网络拓扑 全集群 Fat-tree / Dragonfly 51.2 TB/s 总二分带宽
128节点 × 400GB/s ÷ 2
非阻塞 = 每节点满速400GB/s
3层 Spine-Leaf 架构
确保跨节点带宽

带宽与并行策略匹配

并行类型 通信模式 数据量/次 适合带宽
张量并行 AllReduce (每层2次) ~2 × hidden_dim² = ~130 MB 需要 NVLink (900 GB/s)
流水线并行 P2P Send/Recv ~batch × seq × hidden = ~16 MB InfiniBand 足够
数据并行 AllReduce (梯度) ~模型大小/TP/PP = ~1 GB 可与计算重叠

💡 关键设计原则:高频通信(张量并行)放节点内用 NVLink;低频大数据量通信(数据并行)放节点间可重叠计算

ZeRO 与 Flash Attention 优化

这两项技术可以与 3D 并行同时使用,进一步提升训练效率:

ZeRO (Zero Redundancy Optimizer)

ZeRO 各阶段显存分布对比(4 GPU 示例)
DDP
(ZeRO-0)
GPU 0
参数✓ 梯度✓ 优化器✓
GPU 1
参数✓ 梯度✓ 优化器✓
GPU 2
参数✓ 梯度✓ 优化器✓
GPU 3
参数✓ 梯度✓ 优化器✓
4× 冗余
ZeRO-1
GPU 0
参数✓ 梯度✓ Opt 1/4
GPU 1
参数✓ 梯度✓ Opt 2/4
GPU 2
参数✓ 梯度✓ Opt 3/4
GPU 3
参数✓ 梯度✓ Opt 4/4
省 ~4×
ZeRO-2
GPU 0
参数✓ Grad 1/4 Opt 1/4
GPU 1
参数✓ Grad 2/4 Opt 2/4
GPU 2
参数✓ Grad 3/4 Opt 3/4
GPU 3
参数✓ Grad 4/4 Opt 4/4
省 ~8×
ZeRO-3
Param 1/4
Grad 1/4 Opt 1/4
Param 2/4
Grad 2/4 Opt 2/4
Param 3/4
Grad 3/4 Opt 3/4
Param 4/4
Grad 4/4 Opt 4/4
省 ~N×
参数 (2B/param) 梯度 (2B/param) 优化器状态 (12B/param for AdamW)
ZeRO 阶段 切分内容 显存节省 通信开销
ZeRO-1 优化器状态 ~4× 无额外开销
ZeRO-2 优化器状态 + 梯度 ~8× 轻微增加
ZeRO-3 优化器状态 + 梯度 + 参数 ~线性 显著增加

PyTorch 原生实现 - FSDP:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy
)

# ZeRO-3 等效配置
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    # ShardingStrategy.SHARD_GRAD_OP  # ZeRO-2
    # ShardingStrategy.NO_SHARD       # DDP (ZeRO-0)
)

# 或使用 DeepSpeed
import deepspeed
model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config={
        "zero_optimization": {
            "stage": 3,  # ZeRO-3
            "offload_optimizer": {"device": "cpu"},  # 可选:卸载到CPU
        }
    }
)

Flash Attention

优化 Attention 计算,减少显存 5-20×,加速 2-4×

标准 Attention vs Flash Attention 对比
❌ 标准 Attention
Q (N×d)
K (N×d)
V (N×d)
↓ Q × Kᵀ
完整 N×N 矩阵
存储在 HBM 显存
O(N²) 显存
↓ Softmax
完整 N×N 矩阵
再次存储
↓ × V
Output (N×d)
⚠️ 大量 HBM 读写,GPU 空闲等待
✅ Flash Attention
Q (N×d)
K (N×d)
V (N×d)
↓ 分块加载到 SRAM
Block 1
Qᵢ × Kⱼᵀ
Block 2
Qᵢ × Kⱼᵀ
Block 3
Qᵢ × Kⱼᵀ
...
🔄 在 SRAM 中完成
Softmax + Rescale + 累加
只保留 running max & sum
↓ 直接输出
Output (N×d)
✅ O(N) 显存,减少 HBM 访问
GPU 存储层级详解
寄存器
~20MB/SM
最快
共享内存/L1
228KB/SM
~19TB/s
L2 缓存
50MB
~12TB/s
HBM (显存)
80GB
~3TB/s
🎯 Flash Attention 使用的是:共享内存 (Shared Memory / SMEM)
  • 不是 L1/L2 缓存 — 缓存由硬件自动管理,无法精确控制
  • 是共享内存 (SMEM) — 由程序员显式管理,可精确控制数据布局
  • NVIDIA GPU 中,L1 缓存和共享内存共用同一块 SRAM,可配置分配比例
  • H100 每个 SM 有 228KB SRAM,共 132 个 SM,总计 ~30MB 片上 SRAM
Flash Attention 核心:在 共享内存 (SMEM) 中分块计算 Softmax,避免将 N×N 矩阵写入 HBM
特性 标准 Attention Flash Attention
显存复杂度 O(N²) - 存储完整注意力矩阵 O(N) - 分块计算,不存储中间结果
序列长度 8K ~16 GB ~0.1 GB
计算效率 ~30% GPU 利用率 ~70% GPU 利用率

使用方式:

# 方式1: PyTorch 2.0+ 原生支持
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(q, k, v)

# 方式2: flash-attn 库 (推荐,功能更全)
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)

# 方式3: Transformers 库自动启用
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)

3D 并行 + ZeRO + Flash 组合效果

优化组合 每卡显存 (70B模型) 说明
无优化 ~1170 GB 单卡不可能
3D 并行 (TP8×PP8) ~18 GB 基础配置
3D + Flash Attention ~15 GB 激活值大幅减少
3D + ZeRO-1 ~12 GB 优化器状态分片
3D + ZeRO-2 + Flash ~8-10 GB 推荐配置,可增大batch

💡 最佳实践:3D并行 + ZeRO-2 + Flash Attention + 激活检查点 = 最优显存/性能平衡

关键优化建议

# 推荐的训练配置(含 ZeRO + Flash)
config = {
    # 精度设置
    "precision": "bf16",

    # 3D 并行配置
    "tensor_parallel_size": 8,
    "pipeline_parallel_size": 8,
    "data_parallel_size": 16,

    # ZeRO 配置 (与数据并行结合)
    "zero_stage": 2,              # 0=DDP, 1=优化器分片, 2=+梯度分片, 3=+参数分片
    "zero_offload": False,        # True=卸载到CPU (省显存但慢)

    # Flash Attention 配置
    "flash_attention": True,
    "flash_attention_version": 2, # 1 或 2 (推荐2)

    # 激活检查点
    "activation_checkpointing": True,
    "checkpoint_granularity": "selective",  # full / selective

    # Batch 配置
    "micro_batch_size": 1,
    "gradient_accumulation_steps": 128,
    "global_batch_size": 2048,    # 16 × 128 × 1
}

# DeepSpeed 完整配置示例
deepspeed_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,           # 通信计算重叠
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e8,
    },
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": True,
    },
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 128,
}

# PyTorch 原生配置 (FSDP + Flash)
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision

fsdp_config = {
    "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,  # ZeRO-2
    "mixed_precision": MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    "use_orig_params": True,  # 兼容Flash Attention
}

📌 版本说明