import os

import pytest
import torch
import torch.distributed
import torch.multiprocessing as mp
from peft import LoraConfig, get_peft_model
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from transformers import AutoModelForCausalLM, Qwen3Config

from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
from verl.utils.fsdp_utils import (
    MixedPrecisionPolicy,
    apply_fsdp2,
    get_fsdp_wrap_policy,
    normalize_peft_param_name,
)
from verl.utils.model import convert_weight_keys


def _test_normalize_peft_with_fsdp_worker(rank, world_size, rendezvous_file, strategy):
    """Worker function for testing normalize_peft_param_name with FSDP-wrapped models.

    Args:
        rank: Process rank.
        world_size: Total number of processes.
        rendezvous_file: Path to the rendezvous file used for distributed init.
        strategy: FSDP strategy ("fsdp" or "fsdp2").
    """
    get_torch_device().set_device(rank)
    torch.distributed.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=("dp",))

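    # Tiny Qwen3 config so the test model builds and shards quickly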
    config = Qwen3Config(
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        hidden_size=128,
        intermediate_size=256,
    )

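    # Reference base model without any adapters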
    with torch.device(get_device_name()):
        base_model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
    base_model = base_model.to(device=get_device_name())

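    # Attach a LoRA adapter to every linear layer of the base model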
    lora_config = LoraConfig(
        r=8, lora_alpha=16, target_modules="all-linear", lora_dropout=0.0, bias="none", task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(base_model, lora_config)

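    # A second, identically configured base model that will be wrapped with FSDP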
    with torch.device(get_device_name()):
        base_model_for_fsdp = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
    base_model_for_fsdp = base_model_for_fsdp.to(device=get_device_name())

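    # Wrap both the plain base model and the PEFT model with the requested FSDP flavor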
    if strategy == "fsdp":
        mixed_precision = MixedPrecision(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32
        )

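        # FSDP1: fully shard the plain base model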
        fsdp_base_model = FSDP(
            base_model_for_fsdp,
            use_orig_params=True,
            device_id=get_torch_device().current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=base_model_for_fsdp, is_lora=False),
        )

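        # FSDP1: fully shard the PEFT model with the LoRA-aware wrap policy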
        fsdp_peft_model = FSDP(
            peft_model,
            use_orig_params=True,
            device_id=get_torch_device().current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=peft_model, is_lora=True),
        )
    else:
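        # FSDP2: build the mixed-precision policy shared by both models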
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
        )
        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mp_policy,
        }

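        # apply_fsdp2 shards the module in place, so the sharded model is the same object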
        apply_fsdp2(base_model_for_fsdp, fsdp_kwargs, {})
        fsdp_base_model = base_model_for_fsdp

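        # Same in-place sharding for the PEFT model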
        apply_fsdp2(peft_model, fsdp_kwargs, {})
        fsdp_peft_model = peft_model

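    # Collect full state dicts from both wrapped models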
    if strategy == "fsdp":
        with FSDP.state_dict_type(fsdp_base_model, StateDictType.FULL_STATE_DICT):
            base_state_dict = fsdp_base_model.state_dict()

        with FSDP.state_dict_type(fsdp_peft_model, StateDictType.FULL_STATE_DICT):
            peft_state_dict = fsdp_peft_model.state_dict()
    else:
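        # FSDP2 keeps the original parameter names, so a plain state_dict() call is enough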
        base_state_dict = fsdp_base_model.state_dict()
        peft_state_dict = fsdp_peft_model.state_dict()

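    # Map PEFT parameter names back to the base model's naming scheme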
    normalized_peft_state_dict = normalize_peft_param_name(peft_state_dict)

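    # Run both state dicts through the same key-conversion step, using the unwrapped module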
    base_state_dict = convert_weight_keys(
        base_state_dict, getattr(fsdp_base_model, "_fsdp_wrapped_module", fsdp_base_model)
    )
    normalized_peft_state_dict = convert_weight_keys(
        normalized_peft_state_dict, getattr(fsdp_peft_model, "_fsdp_wrapped_module", fsdp_peft_model)
    )

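    # Compare the resulting key sets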
    base_keys = set(base_state_dict.keys())
    normalized_peft_keys = set(normalized_peft_state_dict.keys())

    print(f"\n=== FSDP {strategy} Test Results ===")
    print(f"Base model keys: {base_keys}")
    print(f"Normalized PEFT keys: {normalized_peft_keys}")

    missing_keys = base_keys - normalized_peft_keys
    if missing_keys:
        print(f"Missing keys from base model: {missing_keys}")

    extra_keys = normalized_peft_keys - base_keys
    if extra_keys:
        print(f"Extra keys not in base model: {extra_keys}")

    # The normalized PEFT keys must cover every base-model key and contain nothing extra
    assert len(missing_keys) == 0, f"Missing keys from base model: {missing_keys}"
    assert len(extra_keys) == 0, f"Extra keys not in base model: {extra_keys}"
    assert base_keys == normalized_peft_keys, "Normalized PEFT keys should exactly match FSDP base model keys"

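    # Parameter shapes must also match key-by-key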
    for key in base_keys:
        base_shape = base_state_dict[key].shape
        peft_shape = normalized_peft_state_dict[key].shape
        assert base_shape == peft_shape, f"Shape mismatch for {key}: base={base_shape}, peft={peft_shape}"

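    # No raw LoRA/adapter entries should survive normalization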
    lora_keys = [k for k in normalized_peft_keys if "lora_" in k or "adapter_" in k]
    assert len(lora_keys) == 0, f"Normalized state dict should not contain LoRA keys, but found: {lora_keys}"

    if rank == 0:
        print(f"✓ All tests passed for FSDP {strategy}")

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


| @pytest.mark.parametrize("world_size", (2,)) |
| @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) |
| def test_normalize_peft_param_name_with_fsdp(world_size, strategy, tmp_path): |
| """Test normalize_peft_param_name with FSDP-wrapped models. |
| |
| This test verifies that after applying FSDP to both base and PEFT models, |
| the normalized PEFT model keys match the FSDP base model keys. |
| """ |
| rendezvous_file = str(tmp_path / f"rdzv_file_normalize_{strategy}") |
| os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) |
|
|
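    # One process per rank; each runs the distributed worker defined above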
    mp.spawn(
        fn=_test_normalize_peft_with_fsdp_worker,
        args=(world_size, rendezvous_file, strategy),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])