# specforge/modeling/target/dflash_target_model.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding

from .sglang_backend import SGLangRunner


@dataclass
class DFlashTargetOutput:
    """Per-batch tensors produced by a target model for DFlash training."""

    hidden_states: torch.Tensor  # [batch, seq_len, hidden_size] (times the number of captured layers when several are concatenated)
    input_ids: torch.Tensor  # [batch, seq_len]
    attention_mask: torch.Tensor  # [batch, seq_len]
    loss_mask: torch.Tensor  # [batch, seq_len]
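

# Illustrative shape check for the dataclass above (hypothetical values:
# batch=2, seq_len=8, hidden_size=16, three captured layers concatenated):
#
#     out = DFlashTargetOutput(
#         hidden_states=torch.zeros(2, 8, 16 * 3),
#         input_ids=torch.zeros(2, 8, dtype=torch.long),
#         attention_mask=torch.ones(2, 8, dtype=torch.long),
#         loss_mask=torch.ones(2, 8, dtype=torch.long),
#     )
#     assert out.hidden_states.shape[:2] == out.input_ids.shape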


class DFlashTargetModel(ABC):
    """Abstract base class for DFlash target model backends."""

    def __init__(self):
        self.capture_layer_ids: Optional[List[int]] = None

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
    ) -> "DFlashTargetModel":
        """Initialize the target model backend."""

    @abstractmethod
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        """Generate context hidden states for DFlash training."""

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        """Set which layers' hidden states to capture."""
        self.capture_layer_ids = layer_ids
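

# Contract sketch for a new backend (illustrative; `MyBackend` is a
# hypothetical name): `from_pretrained` builds the model, and
# `generate_dflash_data` returns a DFlashTargetOutput whose tensors share the
# [batch, seq_len] leading shape.
#
#     class MyBackend(DFlashTargetModel):
#         @classmethod
#         def from_pretrained(cls, path, **kwargs) -> "MyBackend": ...
#
#         def generate_dflash_data(
#             self, input_ids, attention_mask, loss_mask
#         ) -> DFlashTargetOutput: ...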


class SGLangDFlashTargetModel(DFlashTargetModel):
    """Target model backend that runs prefill through an SGLang model runner."""

    def __init__(self, model_runner: SGLangRunner):
        super().__init__()
        self.model_runner = model_runner

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "SGLangDFlashTargetModel":
        tp_size = dist.get_world_size(get_tp_group())
        server_args = ServerArgs(
            model_path=pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            dtype=torch_dtype,
            enable_return_hidden_states=True,  # Critical for DFlash
            disable_cuda_graph=True,
            tp_size=tp_size,
            pp_size=1,
            **kwargs,
        )
        tp_rank = dist.get_rank(get_tp_group())
        moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
        model_config = ModelConfig.from_server_args(server_args)
        model_runner = SGLangRunner(
            model_config=model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=torch.cuda.current_device(),
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            moe_ep_rank=moe_ep_rank,
            moe_ep_size=server_args.ep_size,
            pp_rank=0,
            pp_size=1,
            server_args=server_args,
            nccl_port=None,
        )
        return cls(model_runner)
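
    # Illustrative construction (sketch; assumes torch.distributed and the
    # specforge TP group/mesh were initialized by the training entrypoint,
    # and the model path is a placeholder):
    #
    #     target = SGLangDFlashTargetModel.from_pretrained(
    #         "Qwen/Qwen2.5-7B-Instruct",
    #         torch_dtype=torch.bfloat16,
    #     )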

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        super().set_capture_layers(layer_ids)
        # SGLang has no DFlash-specific hook for selective capture yet. The
        # Eagle3 integration patches the model wrapper with
        # `set_eagle3_layers_to_capture`, so we reuse that mechanism when it
        # is available; capturing only the requested layers is cheaper than
        # returning all hidden states and filtering afterwards. If the hook
        # is absent, the forward pass falls back to whatever hidden states
        # SGLang returns by default.
        if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
            self.model_runner.model.set_eagle3_layers_to_capture(layer_ids)

    @torch.no_grad()
    def _extend(self, reqs):
        # Run a single prefill (extend) pass over `reqs` and return the
        # captured hidden states split back per request. This mirrors the
        # Eagle3 `_extend` helper, simplified to only collect hidden states.
        cache_params = CacheInitParams(
            disable=False,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            page_size=self.model_runner.server_args.page_size,
        )
        tree_cache = RadixCache(cache_params)
        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=self.model_runner.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
        )
        batch.prepare_for_extend()
        if require_mlp_sync(self.model_runner.server_args):
            Scheduler.prepare_mlp_sync_batch_raw(
                batch,
                dp_size=self.model_runner.server_args.dp_size,
                attn_tp_size=1,
                tp_group=self.model_runner.tp_group,
                get_idle_batch=None,
                disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph,
                spec_algorithm=SpeculativeAlgorithm.NONE,
                speculative_num_draft_tokens=None,
                require_mlp_tp_gather=require_mlp_tp_gather(
                    self.model_runner.server_args
                ),
                disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule,
                offload_tags=set(),
            )
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        output, _ = self.model_runner.forward(forward_batch)
        # With the Eagle3-style capture patch, SGLang exposes the captured
        # layers as `aux_hidden_states`; without it, only the final-layer
        # `hidden_states` is available. Accept either.
        input_lens = [len(req.origin_input_ids) for req in reqs]
        # Split the flat [total_tokens, hidden] output back per request.
        if (
            hasattr(output, "aux_hidden_states")
            and output.aux_hidden_states is not None
        ):
            hidden_states_list = torch.split(
                output.aux_hidden_states, input_lens, dim=0
            )
        elif hasattr(output, "hidden_states") and output.hidden_states is not None:
            hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0)
        else:
            raise ValueError("SGLang output does not contain hidden states.")
        self.model_runner.req_to_token_pool.clear()
        self.model_runner.token_to_kv_pool_allocator.clear()
        return hidden_states_list
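
    # `torch.split` with a list of lengths returns one chunk per request.
    # Illustrative numbers only:
    #
    #     flat = torch.randn(7, 4096)              # 7 tokens across the batch
    #     a, b = torch.split(flat, [3, 4], dim=0)  # a: [3, 4096], b: [4, 4096]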

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        # Greedy, single-token sampling params: we only need the prefill
        # hidden states, not a generation.
        sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
        reqs, data_cache = [], []
        if isinstance(input_ids, torch.Tensor):
            input_ids_list = torch.split(input_ids, 1, dim=0)
            attn_mask_list = torch.split(attention_mask, 1, dim=0)
            loss_mask_list = torch.split(loss_mask, 1, dim=0)
            for idx, (curr_ids, curr_attn, curr_loss) in enumerate(
                zip(input_ids_list, attn_mask_list, loss_mask_list)
            ):
                req = Req(
                    rid=str(idx),
                    origin_input_text="",
                    origin_input_ids=curr_ids.view(-1).tolist(),
                    sampling_params=sampling_params,
                )
                req.fill_ids = req.origin_input_ids
                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
                data_cache.append((curr_ids, curr_attn, curr_loss))
                reqs.append(req)
        hidden_states_list = self._extend(reqs)
        # Stack the per-request chunks back into a batch.
        hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0)
        input_ids = torch.cat([d[0] for d in data_cache], dim=0)
        attention_mask = torch.cat([d[1] for d in data_cache], dim=0)
        loss_mask = torch.cat([d[2] for d in data_cache], dim=0)
        # Right-pad in case sequence lengths vary across the batch (training
        # batches are usually fixed-length, so this is often a no-op).
        hidden_states = padding(hidden_states, left=False)
        input_ids = padding(input_ids, left=False)
        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )
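
    # Example batch preparation for this method (illustrative sketch; the
    # tokenizer path is a placeholder and `AutoTokenizer` comes from
    # `transformers`):
    #
    #     tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    #     enc = tok(["Hello world"], return_tensors="pt",
    #               padding="max_length", max_length=32)
    #     loss_mask = enc["attention_mask"].clone()  # supervise all real tokens
    #     out = target.generate_dflash_data(
    #         enc["input_ids"], enc["attention_mask"], loss_mask
    #     )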


class HFDFlashTargetModel(DFlashTargetModel):
    """Target model backend that runs a Hugging Face forward pass with hidden states."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "HFDFlashTargetModel":
        tp_size = get_tp_group().size()
        if tp_size > 1:
            # Shard the model with HF's native tensor parallelism over the
            # specforge TP mesh.
            device_kwargs = {
                "tp_plan": "auto",
                "tp_size": tp_size,
                "device_mesh": get_tp_device_mesh(),
            }
        else:
            device_kwargs = {
                "device_map": device,
            }
        target_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            output_hidden_states=True,
            trust_remote_code=trust_remote_code,
            attn_implementation="flash_attention_2",
            **device_kwargs,
            **kwargs,
        ).eval()
        return cls(target_model)
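
    # Illustrative single-GPU construction (model path and layer ids are
    # placeholders):
    #
    #     target = HFDFlashTargetModel.from_pretrained(
    #         "Qwen/Qwen2.5-7B-Instruct",
    #         torch_dtype=torch.bfloat16,
    #         device="cuda",
    #     )
    #     target.set_capture_layers([2, 14, 24])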

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        # `outputs.hidden_states` is a tuple of num_layers + 1 tensors:
        # index 0 is the embedding output and index i + 1 is the output of
        # decoder layer i, so 0-based layer ids are shifted by one.
        offset = 1
        if self.capture_layer_ids is not None:
            selected = [
                outputs.hidden_states[idx + offset] for idx in self.capture_layer_ids
            ]
            hidden_states = torch.cat(selected, dim=-1)
        else:
            # No capture layers specified: fall back to the last hidden state.
            hidden_states = outputs.hidden_states[-1]
        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )
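

# For a hypothetical 28-layer model with capture_layer_ids=[2, 14, 24] (layer
# ids and depth are placeholders), the method above concatenates
# outputs.hidden_states[3], [15], and [25] along the last dimension:
#
#     h = tuple(torch.randn(2, 8, 16) for _ in range(29))  # embeddings + 28 layers
#     cat = torch.cat([h[i + 1] for i in (2, 14, 24)], dim=-1)
#     assert cat.shape == (2, 8, 48)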


def get_dflash_target_model(
    pretrained_model_name_or_path: str,
    backend: str = "sglang",
    torch_dtype: Optional[torch.dtype] = None,
    device: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> DFlashTargetModel:
    """Factory that dispatches to the SGLang or Hugging Face target model backend."""
    if backend == "sglang":
        return SGLangDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    elif backend == "hf":
        return HFDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    else:
        raise ValueError(f"Invalid backend: {backend!r}; expected 'sglang' or 'hf'.")
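

# End-to-end sketch of the factory (model path, backend choice, and layer ids
# are placeholders; `input_ids`, `attention_mask`, and `loss_mask` are assumed
# to be pre-tokenized [batch, seq_len] tensors):
#
#     target = get_dflash_target_model(
#         "Qwen/Qwen2.5-7B-Instruct",
#         backend="hf",
#         torch_dtype=torch.bfloat16,
#         device="cuda",
#     )
#     target.set_capture_layers([2, 14, 24])
#     out = target.generate_dflash_data(input_ids, attention_mask, loss_mask)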