阅读量 ,评论量

verl 框架 PPO 训练流程完整分析


一、整体执行链路

submit.sh
  └─ ray job submit ... python -u -m verl.trainer.main_ppo
       └─ main()                          # Hydra 入口
            └─ run_ppo(config)            # 解析配置
                 └─ TaskRunner.run()      # 注册 Worker、创建 Trainer、启动训练
                      ├─ add_actor_rollout_worker()
                      ├─ add_critic_worker()
                      ├─ add_ref_policy_worker()
                      ├─ add_reward_model_resource_pool()
                      └─ RayPPOTrainer(...)
                           ├─ init_workers()    # 创建 Worker、初始化模型
                           └─ fit()             # 训练循环

二、main_ppo.py 执行逻辑

文件: verl/trainer/main_ppo.py

2.1 三层结构

层级 函数 职责
入口层 main() Hydra 装饰器解析 YAML 配置
调度层 run_ppo(config) 判断是否使用 Ray,分发到对应执行器
执行层 TaskRunner.run() 注册所有 Worker、创建数据集和 Trainer、启动训练

2.2 TaskRunner.run() 核心流程

# 1. 注册各模型的 Worker 和资源池
self.add_actor_rollout_worker()      # Actor + Rollout (+ Ref)
self.add_critic_worker()             # Critic (如果需要)
self.add_ref_policy_worker()         # Ref Policy (如果需要且不与 Actor 共存)
self.add_reward_model_resource_pool() # Reward Model (如果需要)

# 2. 创建数据集和 tokenizer
train_dataset, val_dataset = ...
tokenizer = hf_tokenizer(...)

# 3. 创建并启动 Trainer
trainer = RayPPOTrainer(config, ...)
trainer.init_workers()
trainer.fit()

三、PPO 训练涉及的模型

3.1 四种模型角色

模型 角色 模型类 训练/推理 是否必需
Actor 策略模型,生成 token 并计算 log_prob AutoModelForCausalLM 训练
Rollout 推理引擎,与 Actor 共享权重 vLLM/SGLang/HF 推理
Critic 价值模型,估计每个 token 的 value AutoModelForTokenClassification 训练 条件必需
Ref Policy 参考策略,计算 KL 散度 AutoModelForCausalLM 推理 条件必需
Reward Model 打分模型 自定义 推理 可选

3.2 模型启用条件

文件: verl/trainer/ppo/utils.py

# Critic: adv_estimator == "gae"(默认值)时启用,或 critic.enable=True
def need_critic(config):
    return config.critic.enable or config.algorithm.adv_estimator == "gae"

# Ref Policy: 使用 KL reward 惩罚 或 KL loss 时启用
def need_reference_policy(config):
    return config.algorithm.use_kl_in_reward or config.actor.use_kl_loss

# Reward Model: reward.reward_model.enable=True 时启用(默认 False)
def need_reward_model(config):
    return config.reward.reward_model.enable

3.3 Worker 类实现位置

Worker 类 文件 说明
ActorRolloutRefWorker verl/workers/fsdp_workers.py:143 管理 Actor + Rollout + Ref
CriticWorker verl/workers/fsdp_workers.py:1274 管理 Critic
DataParallelPPOActor verl/workers/actor/ Actor 的 PPO 训练逻辑封装
DataParallelPPOCritic verl/workers/critic/ Critic 的 PPO 训练逻辑封装

3.4 Role 枚举

class Role(Enum):
    Actor = 0
    Rollout = 1
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4
    RewardModel = 5
    ActorRolloutRef = 6

四、Worker 创建与分发机制

文件: verl/trainer/ppo/ray_trainer.py L768-790, verl/single_controller/ray/base.py

4.1 三步流程

步骤1: create_colocated_worker_cls()  →  融合多个 Worker 类为一个 Ray Actor 类
步骤2: RayWorkerGroup(...)            →  分发到各节点各 GPU 创建实例
步骤3: wg_dict.spawn(prefix_set=...)  →  拆分为独立视图,方法名映射

4.2 详细说明

步骤 1: 融合(Colocate)

# class_dict 示例: {"actor_rollout": ActorRolloutRefWorker, "critic": CriticWorker, "ref": RefWorker}
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)

create_colocated_worker_cls() 位于 verl/single_controller/ray/base.py:981-1022,它: - 创建一个 WorkerDict 类 - 将各 Worker 的方法加上前缀(如 actor_rollout_init_modelcritic_init_model) - 合并到同一个类中,使多个 Worker 共存于一个 Ray Actor 进程

步骤 2: 分发

wg_dict = self.ray_worker_group_cls(
    resource_pool=resource_pool,
    ray_cls_with_init=RayClassWithInitArgs(cls=worker_dict_cls, ...),
    ...
)

RayWorkerGroup._init_with_resource_pool() 位于 base.py:531-574,双重循环:

for node in nodes:           # 遍历节点
    for gpu in gpus_per_node:  # 遍历每个节点的 GPU
        _create_worker()       # 创建一个 Ray Actor 实例

_create_worker() (L616-676) 为每个 Worker 设置环境变量(RANK, WORLD_SIZE, MASTER_ADDR 等),然后执行:

cls.options(**options).remote()  # 在目标 GPU 上创建 Ray Actor

步骤 3: 拆分视图

spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
# 结果: {"actor_rollout": WorkerGroup视图, "critic": WorkerGroup视图, ...}

spawn() (L711-744) 将融合的 WorkerGroup 拆分成独立视图,方法名映射回原始名称: - actor_rollout_init_modelinit_model(actor_rollout 视图) - critic_init_modelinit_model(critic 视图)

4.3 示例:1 node × 8 GPU

PlacementGroup: 8 个 bundle,每个绑定 1 GPU
  ├─ GPU 0: Ray Actor (WorkerDict 实例)
  │    ├─ ActorRolloutRefWorker (rank=0)
  │    ├─ CriticWorker (rank=0)
  │    └─ RefWorker (rank=0)  ← 如果 colocated
  ├─ GPU 1: Ray Actor (WorkerDict 实例)
  │    ├─ ActorRolloutRefWorker (rank=1)
  │    ├─ CriticWorker (rank=1)
  │    └─ RefWorker (rank=1)
  ...
  └─ GPU 7: Ray Actor (WorkerDict 实例)
       ├─ ActorRolloutRefWorker (rank=7)
       ├─ CriticWorker (rank=7)
       └─ RefWorker (rank=7)

4.4 资源池管理


五、训练循环 RayPPOTrainer.fit()

文件: verl/trainer/ppo/ray_trainer.py L1240-1621

5.1 每个训练 step 的 7 个阶段

┌─────────────────────────────────────────────────────┐
│                  Training Step                       │
│                                                     │
│  Stage 1: Rollout (生成响应)                         │
│    actor_rollout → generate_sequences()             │
│    输入: prompts                                     │
│    输出: responses, old_log_probs (可选)             │
│                                                     │
│  Stage 2: Reward (计算奖励)                          │
│    reward_model → compute_rm_score() 或自定义函数     │
│    输入: prompts + responses                         │
│    输出: rewards                                     │
│                                                     │
│  Stage 3: Old Log Prob (计算旧策略概率)              │
│    actor_rollout → compute_log_prob()                │
│    输入: prompts + responses                         │
│    输出: old_log_probs                               │
│                                                     │
│  Stage 4: Ref Log Prob (计算参考策略概率)            │
│    ref_policy → compute_ref_log_prob()               │
│    输入: prompts + responses                         │
│    输出: ref_log_probs (用于 KL 惩罚/损失)          │
│                                                     │
│  Stage 5: Values (计算价值估计)                      │
│    critic → compute_values()                         │
│    输入: prompts + responses                         │
│    输出: values                                      │
│                                                     │
│  Stage 6: Advantage (计算优势函数)                   │
│    本地计算 GAE / GRPO 等                            │
│    输入: rewards, values, ref_log_probs              │
│    输出: advantages, returns                         │
│                                                     │
│  Stage 7: Update Policy (更新模型参数)               │
│    actor → update_policy()   # PPO loss              │
│    critic → update_critic()  # Value loss            │
│    输入: old_log_probs, advantages, returns          │
│    输出: 更新后的 Actor & Critic 权重                │
│                                                     │
└─────────────────────────────────────────────────────┘

5.2 数据流向

prompts ──→ [Rollout] ──→ responses
                              │
              ┌───────────────┼───────────────┐
              ▼               ▼               ▼
         [Reward]     [Actor log_prob]   [Ref log_prob]
              │               │               │
              ▼               ▼               ▼
           rewards      old_log_probs    ref_log_probs
              │               │               │
              └───────┬───────┘               │
                      ▼                       │
                  [Critic]                    │
                      │                       │
                      ▼                       │
                   values                     │
                      │                       │
              ┌───────┴───────────────────────┘
              ▼
         [GAE/GRPO 计算]
              │
              ▼
      advantages, returns
              │
       ┌──────┴──────┐
       ▼              ▼
 [Actor Update]  [Critic Update]

六、init_model() 流程详解

6.1 调用顺序

ray_trainer.pyinit_workers() 中:

# 1. 先初始化 Critic(如果需要)
critic_wg.init_model()

# 2. 再初始化 Ref(如果需要)
ref_wg.init_model()

# 3. 最后初始化 Actor + Rollout
actor_rollout_wg.init_model()

这个顺序确保 colocated 场景下显存分配合理。

6.2 ActorRolloutRefWorker.init_model()

文件: verl/workers/fsdp_workers.py:851-950

通过 @register(dispatch_mode=Dispatch.ONE_TO_ALL) 装饰,所有 GPU 上的 Worker 同时执行。

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    # 0. 导入外部库(自定义模型注册)
    import_external_libs(self.config.model.get("external_lib", None))

    # 1. 构建 Actor 模型 (如果是 Actor 或 Rollout 角色)
    if self._is_actor or self._is_rollout:
        (self.actor_module_fsdp, self.actor_optimizer,
         self.actor_lr_scheduler, self.actor_model_config
        ) = self._build_model_optimizer(model_path=..., role="actor", ...)

    # 2. 封装为 DataParallelPPOActor
    if self._is_actor:
        self.actor = DataParallelPPOActor(config=..., actor_module=self.actor_module_fsdp, ...)

    # 3. 构建 Rollout 引擎 (vLLM/SGLang/HF)
    if self._is_rollout:
        self._build_rollout(trust_remote_code=...)

    # 4. 构建 Ref 模型 (复用 _build_model_optimizer, role="ref")
    if self._is_ref:
        self.ref_module_fsdp = self._build_model_optimizer(role="ref", optim_config=None, ...)

6.3 _build_model_optimizer() — 8 阶段流水线

文件: verl/workers/fsdp_workers.py:330-676

模型路径 → Tokenizer → AutoConfig → 模型类检测 → from_pretrained → Monkey Patch → LoRA → FSDP → Optimizer
阶段 行号 操作 说明
1 L371-378 加载 Tokenizer & Processor hf_tokenizer(), hf_processor()
2 L388-425 加载模型配置 AutoConfig.from_pretrained() + override_config 覆盖
3 L436-460 自动检测模型类 Vision2Seq → CausalLM → ImageTextToText → AutoModel
4 L462-468 加载预训练权重 actor_module_class.from_pretrained()
5 L481-490 Monkey Patch remove_padding, fused_kernels, tiled_mlp, 序列并行
6 L498-527 LoRA 适配 加载已有 adapter 或新建 LoRA
7 L594-632 FSDP 包装 FSDP1 或 FSDP2,Actor 不 offload,Ref 开启 CPU offload
8 L640-675 构建 Optimizer 仅 role=“actor”,支持 constant/cosine LR scheduler

阶段 3 模型类检测的详细逻辑:

# 优先级 1: 检查 config.auto_map (remote code)
if has_remote_code:
    match auto_class:
        case "AutoModelForVision2Seq":    → AutoModelForVision2Seq
        case "AutoModelForCausalLM":      → AutoModelForCausalLM
        case "AutoModelForImageTextToText": → AutoModelForImageTextToText
        case _:                           → AutoModel

# 优先级 2: 从 HF _model_mapping 注册表匹配
else:
    if config in AutoModelForVision2Seq._model_mapping:    → AutoModelForVision2Seq
    elif config in AutoModelForCausalLM._model_mapping:    → AutoModelForCausalLM
    elif config in AutoModelForImageTextToText._model_mapping: → AutoModelForImageTextToText
    else:                                                  → AutoModel

Actor vs Ref 的差异:

对比项 Actor (role=“actor”) Ref (role=“ref”)
torch_dtype fp32 bf16
Optimizer 无 (optim_config=None)
CPU Offload 关闭 开启 (CPUOffload(offload_params=True))
FSDP forward_only False True
QAT 可选 不应用

6.4 _build_rollout() — Rollout 引擎构建

文件: verl/workers/fsdp_workers.py:679-742

# 1. 构建推理 device_mesh
infer_tp = config.rollout.tensor_model_parallel_size * config.rollout.data_parallel_size
infer_pp = config.rollout.pipeline_model_parallel_size
rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp, infer_pp))

# 2. 从注册表获取 Rollout 类并实例化
self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(
    config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
)

get_rollout_class() 位于 verl/workers/rollout/base.py:90-104,通过 _ROLLOUT_REGISTRY 注册表查找: - (vllm, async) → vLLM AsyncLLM - (sglang, async) → SGLang - (hf, sync) → HuggingFace 原生

重要: Rollout 引擎和 Actor 共享同一个模型权重(Hybrid Engine),通过 rollout_mode() / trainer_mode() 上下文切换。

6.5 CriticWorker.init_model()

文件: verl/workers/fsdp_workers.py:1607-1629

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    import_external_libs(self.config.model.get("external_lib", None))

    # 1. 构建 Critic 模型
    self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = \
        self._build_critic_model_optimizer(self.config)

    # 2. 封装为 DataParallelPPOCritic
    self.critic = DataParallelPPOCritic(
        config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer
    )

6.6 _build_critic_model_optimizer() — Critic 模型构建

文件: verl/workers/fsdp_workers.py:1358-1605

与 Actor 的核心差异:

对比项 Actor Critic
模型类 AutoModelForCausalLM AutoModelForTokenClassification
输出 下一个 token 的概率分布 (vocab_size) 每个 token 的标量 value (num_labels=1)
加载函数 from_pretrained() load_valuehead_model()
LoRA task_type TaskType.CAUSAL_LM TaskType.TOKEN_CLS

关键代码:

critic_model_config.num_labels = 1  # Value head 输出维度为 1

critic_module = load_valuehead_model(
    local_path, torch_dtype, critic_model_config, trust_remote_code
)

6.7 load_valuehead_model() — 带 Value Head 的模型加载

文件: verl/utils/model.py:628-667

def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
    try:
        # 方案1: AutoModelForTokenClassification(将 LM head 替换为 nn.Linear(hidden_size, 1))
        model = AutoModelForTokenClassification.from_pretrained(...)
        return model
    except:
        # 方案2: trl 的 AutoModelForCausalLMWithValueHead(在 CausalLM 基础上加 value head)
        ori_model = AutoModelForCausalLM.from_pretrained(...)
        model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)
        return model

七、如何修改模型结构

方法 1: YAML override_config(最简单)

适用场景: 修改 HuggingFace config 中已有的参数。

actor_rollout_ref:
  model:
    override_config:
      num_hidden_layers: 24
      hidden_size: 2048
      attn_implementation: flash_attention_2

对应代码: fsdp_workers.py:412-425

方法 2: external_lib 注册自定义模型(推荐)

适用场景: 使用全新的模型架构。

actor_rollout_ref:
  model:
    external_lib: /path/to/your/custom_model_lib

步骤: 1. 创建自定义模型文件,继承 PreTrainedModel,注册到 HuggingFace AutoModel 2. 在模型的 config.json 中设置 auto_map 指向自定义类 3. YAML 中配置 external_lib 路径

对应代码: fsdp_workers.py:855import_external_libs()

方法 3: trust_remote_code + 自定义 modeling 文件

适用场景: 模型仓库自带 modeling_xxx.py

actor_rollout_ref:
  model:
    trust_remote_code: true
    path: /path/to/model_with_custom_code

代码自动检测 config.auto_map 字段加载对应模型类 (fsdp_workers.py:436-460)。

方法 4: 自定义 Critic Value Head

适用场景: 修改 Critic 的输出头。

方法 5: Monkey Patch 修改运行时行为

适用场景: 修改 attention、MLP 等计算逻辑。

框架内置支持: - use_remove_padding: 移除 padding 加速 - use_fused_kernels: FlashAttention 等融合算子 - use_tiled_mlp: 分块 MLP 计算 - ulysses_sp_size: Ulysses 序列并行

也可在 external_lib 中做自定义 monkey patch。

方法 6: 直接修改框架源码(最后手段)

关键修改点: - Actor: fsdp_workers.py:462-468from_pretrained() 之后、FSDP 包装之前 - Critic: fsdp_workers.py:1431-1436load_valuehead_model() 之后、FSDP 包装之前


八、关键配置文件

配置文件 说明
verl/trainer/config/ppo_trainer.yaml PPO 训练主配置,默认 GAE 优势估计
verl/trainer/config/model/hf_model.yaml 模型配置,默认 Qwen/Qwen2.5-0.5B-Instruct,LoRA/MTP 等
verl/trainer/config/actor/actor.yaml Actor 通用配置:PPO 超参、clip_ratio、loss 等
verl/trainer/config/actor/dp_actor.yaml FSDP Actor 配置:FSDP 策略、gradient clipping、QAT
verl/trainer/config/critic/dp_critic.yaml FSDP Critic 配置
verl/trainer/config/ref/dp_ref.yaml Ref Policy 配置:forward_only=True
verl/trainer/config/reward/reward.yaml Reward 配置:reward_model.enable 默认 False
verl/trainer/config/rollout/rollout.yaml Rollout 配置:name=vllm, mode=async, load_format=dummy

九、关键源码文件

文件 说明
verl/trainer/main_ppo.py 主入口:main() → run_ppo() → TaskRunner.run()
verl/trainer/ppo/ray_trainer.py RayPPOTrainer:init_workers() + fit() 训练循环
verl/trainer/ppo/utils.py need_critic(), need_reference_policy() 等判断函数
verl/workers/fsdp_workers.py ActorRolloutRefWorker + CriticWorker 的完整实现
verl/utils/model.py load_valuehead_model(), update_model_config() 等工具
verl/workers/rollout/base.py get_rollout_class() Rollout 引擎注册表
verl/workers/actor/ DataParallelPPOActor 训练逻辑
verl/workers/critic/ DataParallelPPOCritic 训练逻辑
verl/single_controller/ray/base.py create_colocated_worker_cls(), RayWorkerGroup, ResourcePoolManager