#!/usr/bin/env python3
"""
scripts/merge_checkpoints.py — Slerp (Spherical Linear Interpolation) checkpoint merge.
Merges two model checkpoints (e.g., SFT + DPO) using SLERP interpolation
to balance knowledge retention and alignment improvement.
Reference: Nemotron-H paper — SLERP merging reduces alignment tax.
Usage:
python scripts/merge_checkpoints.py \
--ckpt_a checkpoints/3b_sft_v2/checkpoint-best \
--ckpt_b checkpoints/3b_dpo/checkpoint-merged \
--output checkpoints/3b_dpo/checkpoint-slerp \
--alpha 0.5
alpha=0.0 → pure ckpt_a (SFT)
alpha=1.0 → pure ckpt_b (DPO)
alpha=0.5 → equal blend (recommended starting point)
"""

from __future__ import annotations

import argparse
import shutil
from pathlib import Path

import torch
import yaml


def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Spherical linear interpolation between two tensors.

    Args:
        t: Interpolation factor in [0, 1]. 0 → v0, 1 → v1.
        v0: First tensor (flattened internally).
        v1: Second tensor (same shape as v0).
        eps: Small value to avoid division by zero.

    Returns:
        Interpolated tensor with the same shape as v0.
    """
    original_shape = v0.shape
    v0_flat = v0.flatten().float()
    v1_flat = v1.flatten().float()

    # Normalize to unit length so the angle between the vectors is well defined
    v0_norm = v0_flat / (v0_flat.norm() + eps)
    v1_norm = v1_flat / (v1_flat.norm() + eps)

    # Cosine of the angle between the vectors
    cos_omega = torch.dot(v0_norm, v1_norm).clamp(-1.0, 1.0)

    # If the vectors are nearly parallel or anti-parallel, sin(Ω) ≈ 0 and SLERP
    # becomes numerically unstable; fall back to linear interpolation
    if abs(cos_omega.item()) > 0.9995:
        result = (1.0 - t) * v0_flat + t * v1_flat
        return result.reshape(original_shape).to(v0.dtype)

    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)
    s0 = torch.sin((1.0 - t) * omega) / sin_omega
    s1 = torch.sin(t * omega) / sin_omega

    # Interpolate the original (non-normalized) vectors to preserve scale
    result = s0 * v0_flat + s1 * v1_flat
    return result.reshape(original_shape).to(v0.dtype)
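

# Illustrative sanity check (not executed by the merge pipeline): at the
# endpoints SLERP should reproduce the inputs, e.g.
#   v0, v1 = torch.randn(8), torch.randn(8)
#   assert torch.allclose(slerp(0.0, v0, v1), v0, atol=1e-5)
#   assert torch.allclose(slerp(1.0, v0, v1), v1, atol=1e-5)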


def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """Simple linear interpolation."""
    return ((1.0 - t) * v0.float() + t * v1.float()).to(v0.dtype)


def merge_state_dicts(
    sd_a: dict[str, torch.Tensor],
    sd_b: dict[str, torch.Tensor],
    alpha: float = 0.5,
    method: str = "slerp",
) -> dict[str, torch.Tensor]:
    """Merge two state dicts using SLERP or LERP.

    Args:
        sd_a: State dict A (e.g., SFT model).
        sd_b: State dict B (e.g., DPO model).
        alpha: Interpolation factor. 0 → A, 1 → B.
        method: "slerp" or "lerp".

    Returns:
        Merged state dict.
    """
    interp_fn = slerp if method == "slerp" else lerp
    merged = {}

    keys_a = set(sd_a.keys())
    keys_b = set(sd_b.keys())
    common = keys_a & keys_b
    only_a = keys_a - keys_b
    only_b = keys_b - keys_a

    if only_a:
        print(f"[WARN] {len(only_a)} keys only in ckpt_a (kept as-is)")
    if only_b:
        print(f"[WARN] {len(only_b)} keys only in ckpt_b (kept as-is)")

    for key in sorted(common):
        va = sd_a[key]
        vb = sd_b[key]
        if va.shape != vb.shape:
            print(f"[WARN] Shape mismatch for {key}: {va.shape} vs {vb.shape}, keeping ckpt_a")
            merged[key] = va
            continue
        # Only interpolate floating-point parameters (skip int buffers, etc.)
        if va.is_floating_point() and va.numel() > 1:
            merged[key] = interp_fn(alpha, va, vb)
        else:
            merged[key] = va  # Keep non-float/scalar entries from ckpt_a

    # Carry over keys unique to either checkpoint
    for key in only_a:
        merged[key] = sd_a[key]
    for key in only_b:
        merged[key] = sd_b[key]

    return merged
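

# Minimal usage sketch (hypothetical tensors, not run by this script):
#   sd_a = {"w": torch.ones(4)}
#   sd_b = {"w": torch.full((4,), 3.0)}
#   merge_state_dicts(sd_a, sd_b, alpha=0.5)
# Here the vectors are colinear, so slerp() falls back to LERP and the merged
# value is a tensor of 2.0.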


def main():
    parser = argparse.ArgumentParser(description="SLERP checkpoint merge")
    parser.add_argument("--ckpt_a", type=Path, required=True,
                        help="Path to checkpoint A (e.g., SFT)")
    parser.add_argument("--ckpt_b", type=Path, required=True,
                        help="Path to checkpoint B (e.g., DPO)")
    parser.add_argument("--output", type=Path, required=True,
                        help="Output checkpoint directory")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Interpolation factor (0=A, 1=B, default 0.5)")
    parser.add_argument("--method", choices=["slerp", "lerp"], default="slerp",
                        help="Interpolation method (default: slerp)")
    args = parser.parse_args()

    print(f"Merge: {args.ckpt_a.name} ←({1 - args.alpha:.1%})— ({args.alpha:.1%})→ {args.ckpt_b.name}")
    print(f"Method: {args.method}, alpha={args.alpha}")

    # Load state dicts
    print("Loading checkpoint A...")
    sd_a = torch.load(args.ckpt_a / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_a)} keys loaded")
    print("Loading checkpoint B...")
    sd_b = torch.load(args.ckpt_b / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_b)} keys loaded")

    # Merge
    print("Merging...")
    merged_sd = merge_state_dicts(sd_a, sd_b, alpha=args.alpha, method=args.method)
    print(f"  {len(merged_sd)} keys in merged state dict")

    # Save
    args.output.mkdir(parents=True, exist_ok=True)
    torch.save(merged_sd, args.output / "model.pt")

    # Copy config from ckpt_a
    config_src = args.ckpt_a / "config.yaml"
    if config_src.exists():
        shutil.copy2(str(config_src), str(args.output / "config.yaml"))

    # Copy tokenizer files if available
    for tok_name in ["tokenizer.json", "tokenizer.model"]:
        tok_src = args.ckpt_a / tok_name
        if tok_src.exists():
            shutil.copy2(str(tok_src), str(args.output / tok_name))

    # Write merge metadata
    meta = {
        "ckpt_a": str(args.ckpt_a),
        "ckpt_b": str(args.ckpt_b),
        "alpha": args.alpha,
        "method": args.method,
    }
    with open(args.output / "merge_info.yaml", "w") as f:
        yaml.safe_dump(meta, f)

    size_mb = (args.output / "model.pt").stat().st_size / 1e6
    print(f"\nMerged checkpoint saved → {args.output} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    main()
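
# Example downstream use (illustrative; assumes the model.pt layout written
# above and an already-instantiated `model` matching the checkpoint):
#   sd = torch.load("checkpoints/3b_dpo/checkpoint-slerp/model.pt",
#                   map_location="cpu", weights_only=True)
#   model.load_state_dict(sd)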