from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding

from .sglang_backend import SGLangRunner


@dataclass
class DFlashTargetOutput:
    hidden_states: torch.Tensor  # [batch, seq_len, num_capture_layers * hidden_size]
    input_ids: torch.Tensor  # [batch, seq_len]
    attention_mask: torch.Tensor  # [batch, seq_len]
    loss_mask: torch.Tensor  # [batch, seq_len]


class DFlashTargetModel(ABC):
    """Abstract base class for DFlash target model backends."""

    def __init__(self):
        self.capture_layer_ids = None

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = None,
        device: str = None,
        cache_dir: Optional[str] = None,
        **kwargs,
    ) -> "DFlashTargetModel":
        """Initialize the target model backend."""

    @abstractmethod
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        """Generate context hidden states for DFlash training."""

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        """Set which layers' hidden states to capture."""
        self.capture_layer_ids = layer_ids


class SGLangDFlashTargetModel(DFlashTargetModel):
    def __init__(self, model_runner: SGLangRunner):
        super().__init__()
        self.model_runner = model_runner

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = None,
        device: str = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "SGLangDFlashTargetModel":
        tp_size = dist.get_world_size(get_tp_group())
        server_args = ServerArgs(
            model_path=pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            dtype=torch_dtype,
            enable_return_hidden_states=True,  # Critical for DFlash
            disable_cuda_graph=True,
            tp_size=tp_size,
            pp_size=1,
            **kwargs,
        )
        tp_rank = dist.get_rank(get_tp_group())
        moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
        model_config = ModelConfig.from_server_args(server_args)
        model_runner = SGLangRunner(
            model_config=model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=torch.cuda.current_device(),
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            moe_ep_rank=moe_ep_rank,
            moe_ep_size=server_args.ep_size,
            pp_rank=0,
            pp_size=1,
            server_args=server_args,
            nccl_port=None,
        )
        return cls(model_runner)

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        super().set_capture_layers(layer_ids)
        # SGLang has no generic "capture these layers" hook. The Eagle3 patch adds
        # `set_eagle3_layers_to_capture` to the model wrapper, which captures only
        # the requested layers and is therefore more memory-efficient than
        # returning all hidden states and filtering afterwards. Reuse that hook
        # when it is available; otherwise the forward pass falls back to
        # returning the full hidden states.
        if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
            self.model_runner.model.set_eagle3_layers_to_capture(layer_ids)

    @torch.no_grad()
    def _extend(self, reqs):
        # Similar to the Eagle3 `_extend`, but simplified: we only need the
        # hidden states from a single prefill pass.
        cache_params = CacheInitParams(
            disable=False,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            page_size=self.model_runner.server_args.page_size,
        )
        tree_cache = RadixCache(cache_params)
        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=self.model_runner.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
        )
        batch.prepare_for_extend()
        if require_mlp_sync(self.model_runner.server_args):
            Scheduler.prepare_mlp_sync_batch_raw(
                batch,
                dp_size=self.model_runner.server_args.dp_size,
                attn_tp_size=1,
                tp_group=self.model_runner.tp_group,
                get_idle_batch=None,
                disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph,
                spec_algorithm=SpeculativeAlgorithm.NONE,
                speculative_num_draft_tokens=None,
                require_mlp_tp_gather=require_mlp_tp_gather(
                    self.model_runner.server_args
                ),
                disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule,
                offload_tags=set(),
            )
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        output, _ = self.model_runner.forward(forward_batch)
        # With the Eagle3 patch the captured layers come back in
        # `aux_hidden_states`; unpatched SGLang returns `hidden_states`. Both
        # cases are handled below.
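        # Illustrative shape note (an assumption based on the Eagle3 capture
        # path, not an SGLang guarantee): the batch runs as one packed prefill,
        # so with two requests of lengths 5 and 7, three captured layers, and
        # hidden_size=4096, `aux_hidden_states` would be [12, 3 * 4096], and
        # torch.split(..., [5, 7], dim=0) recovers the per-request tensors.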
        input_lens = [len(req.origin_input_ids) for req in reqs]

        # Split the packed hidden states back into per-request tensors.
        if (
            hasattr(output, "aux_hidden_states")
            and output.aux_hidden_states is not None
        ):
            hidden_states_list = torch.split(
                output.aux_hidden_states, input_lens, dim=0
            )
        elif hasattr(output, "hidden_states") and output.hidden_states is not None:
            hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0)
        else:
            raise ValueError("SGLang output does not contain hidden states.")

        self.model_runner.req_to_token_pool.clear()
        self.model_runner.token_to_kv_pool_allocator.clear()
        return hidden_states_list

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
        reqs, data_cache = [], []
        if not isinstance(input_ids, torch.Tensor):
            raise TypeError("generate_dflash_data expects batched tensor inputs.")
        input_ids_list = torch.split(input_ids, 1, dim=0)
        attn_mask_list = torch.split(attention_mask, 1, dim=0)
        loss_mask_list = torch.split(loss_mask, 1, dim=0)
        for idx, (curr_ids, curr_attn, curr_loss) in enumerate(
            zip(input_ids_list, attn_mask_list, loss_mask_list)
        ):
            req = Req(
                rid=str(idx),
                origin_input_text="",
                origin_input_ids=curr_ids.view(-1).tolist(),
                sampling_params=sampling_params,
            )
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            data_cache.append((curr_ids, curr_attn, curr_loss))
            reqs.append(req)

        hidden_states_list = self._extend(reqs)

        # Stack the per-request results back into a batch.
        hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0)
        input_ids = torch.cat([d[0] for d in data_cache], dim=0)
        attention_mask = torch.cat([d[1] for d in data_cache], dim=0)
        loss_mask = torch.cat([d[2] for d in data_cache], dim=0)

        # Training batches are usually fixed-length; pad defensively in case
        # sequence lengths vary within the batch.
        hidden_states = padding(hidden_states, left=False)
        input_ids = padding(input_ids, left=False)

        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )


class HFDFlashTargetModel(DFlashTargetModel):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = None,
        device: str = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "HFDFlashTargetModel":
        tp_size = get_tp_group().size()
        if tp_size > 1:
            device_kwargs = {
                "tp_plan": "auto",
                "tp_size": tp_size,
                "device_mesh": get_tp_device_mesh(),
            }
        else:
            device_kwargs = {"device_map": device}
        target_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            output_hidden_states=True,
            trust_remote_code=trust_remote_code,
            attn_implementation="flash_attention_2",
            **device_kwargs,
            **kwargs,
        ).eval()
        return cls(target_model)

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        # `outputs.hidden_states` is a tuple of L + 1 tensors: hidden_states[0]
        # is the embedding output, and hidden_states[i + 1] is the output of
        # transformer layer i. `self.capture_layer_ids` holds 0-based layer
        # indices, so they must be shifted by one.
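        # Example (illustrative values): on a 32-layer model with
        # capture_layer_ids = [1, 16, 30], we concatenate
        # outputs.hidden_states[2], outputs.hidden_states[17], and
        # outputs.hidden_states[31] along the feature dimension.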
        offset = 1
        if self.capture_layer_ids is not None:
            selected = [
                outputs.hidden_states[idx + offset] for idx in self.capture_layer_ids
            ]
            hidden_states = torch.cat(selected, dim=-1)
        else:
            # No capture layers specified: fall back to the last hidden state.
            hidden_states = outputs.hidden_states[-1]

        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )


def get_dflash_target_model(
    pretrained_model_name_or_path: str,
    backend: str = "sglang",
    torch_dtype: torch.dtype = None,
    device: str = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> DFlashTargetModel:
    """Build a DFlash target model on the requested backend ("sglang" or "hf")."""
    if backend == "sglang":
        return SGLangDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    elif backend == "hf":
        return HFDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    else:
        raise ValueError(f"Invalid backend: {backend}")
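# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: the model path, layer ids, and
# tensor shapes are assumptions, not part of this module's API contract).
#
#     target = get_dflash_target_model(
#         "meta-llama/Llama-3.1-8B-Instruct",
#         backend="hf",
#         torch_dtype=torch.bfloat16,
#         device="cuda",
#     )
#     target.set_capture_layers([1, 16, 30])
#     out = target.generate_dflash_data(input_ids, attention_mask, loss_mask)
#     # out.hidden_states: [batch, seq_len, 3 * hidden_size]
# ---------------------------------------------------------------------------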