from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding

from .sglang_backend import SGLangRunner


@dataclass
class DFlashTargetOutput:
    """Output of ``DFlashTargetModel.generate_dflash_data``.

    ``hidden_states`` holds the captured target-model hidden states; when
    multiple capture layers are set, they are concatenated along the last
    dimension. The remaining fields mirror the corresponding inputs.
    """

    hidden_states: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor


class DFlashTargetModel(ABC):
    """Abstract base class for DFlash target-model backends."""

    def __init__(self):
        self.capture_layer_ids: Optional[List[int]] = None

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **kwargs,
    ) -> "DFlashTargetModel":
        """Initialize the target model backend."""

    @abstractmethod
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        """Generate context hidden states for DFlash training."""

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        """Set which layers' hidden states to capture."""
        self.capture_layer_ids = layer_ids
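

# Minimal usage sketch (illustrative only; the model id and layer ids below
# are placeholders, not values shipped with this module):
#
#   target = get_dflash_target_model("my-org/my-target-model", backend="hf")
#   target.set_capture_layers([2, 16, 29])
#   out = target.generate_dflash_data(input_ids, attention_mask, loss_mask)
#   # out.hidden_states stacks the three captured layers along the last dim.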


class SGLangDFlashTargetModel(DFlashTargetModel):
    def __init__(self, model_runner: SGLangRunner):
        super().__init__()
        self.model_runner = model_runner

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "SGLangDFlashTargetModel":
        tp_size = dist.get_world_size(get_tp_group())
        # enable_return_hidden_states makes the runner surface hidden states
        # on the forward output; disable_cuda_graph keeps shapes dynamic for
        # variable-length training batches.
        server_args = ServerArgs(
            model_path=pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            dtype=torch_dtype,
            enable_return_hidden_states=True,
            disable_cuda_graph=True,
            tp_size=tp_size,
            pp_size=1,
            **kwargs,
        )

        tp_rank = dist.get_rank(get_tp_group())
        # Expert-parallel rank derives from the TP rank: consecutive blocks of
        # tp_size // ep_size ranks share one expert-parallel group.
        moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
        model_config = ModelConfig.from_server_args(server_args)

        model_runner = SGLangRunner(
            model_config=model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=torch.cuda.current_device(),
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            moe_ep_rank=moe_ep_rank,
            moe_ep_size=server_args.ep_size,
            pp_rank=0,
            pp_size=1,
            server_args=server_args,
            nccl_port=None,
        )
        return cls(model_runner)

    def set_capture_layers(self, layer_ids: List[int]) -> None:
        super().set_capture_layers(layer_ids)
        # Reuse SGLang's EAGLE3 capture hook so the model records the hidden
        # states of the requested layers during the forward pass.
        if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
            self.model_runner.model.set_eagle3_layers_to_capture(layer_ids)

    @torch.no_grad()
    def _extend(self, reqs):
        # Build a throwaway radix cache backed by the runner's memory pools.
        cache_params = CacheInitParams(
            disable=False,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            page_size=self.model_runner.server_args.page_size,
        )
        tree_cache = RadixCache(cache_params)

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=self.model_runner.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
        )
        batch.prepare_for_extend()

        if require_mlp_sync(self.model_runner.server_args):
            Scheduler.prepare_mlp_sync_batch_raw(
                batch,
                dp_size=self.model_runner.server_args.dp_size,
                attn_tp_size=1,
                tp_group=self.model_runner.tp_group,
                get_idle_batch=None,
                disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph,
                spec_algorithm=SpeculativeAlgorithm.NONE,
                speculative_num_draft_tokens=None,
                require_mlp_tp_gather=require_mlp_tp_gather(
                    self.model_runner.server_args
                ),
                disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule,
                offload_tags=set(),
            )

        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL

        output, _ = self.model_runner.forward(forward_batch)

        # The extend (prefill) pass returns hidden states for all requests
        # flattened along the token dimension; split them back per request.
        input_lens = [len(req.origin_input_ids) for req in reqs]

        if (
            hasattr(output, "aux_hidden_states")
            and output.aux_hidden_states is not None
        ):
            # aux_hidden_states holds the layers registered via
            # set_eagle3_layers_to_capture.
            hidden_states_list = torch.split(
                output.aux_hidden_states, input_lens, dim=0
            )
        elif hasattr(output, "hidden_states") and output.hidden_states is not None:
            hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0)
        else:
            raise ValueError("SGLang output does not contain hidden states.")

        # Release the KV cache and request slots used by this batch.
        self.model_runner.req_to_token_pool.clear()
        self.model_runner.token_to_kv_pool_allocator.clear()

        return hidden_states_list
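
    # _extend() returns one [seq_len_i, hidden_dim] tensor per request, in
    # request order; generate_dflash_data() below re-stacks them into a
    # single right-padded batch.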

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
        reqs, data_cache = [], []

        if not isinstance(input_ids, torch.Tensor):
            raise TypeError(f"input_ids must be a torch.Tensor, got {type(input_ids)}")

        input_ids_list = torch.split(input_ids, 1, dim=0)
        attn_mask_list = torch.split(attention_mask, 1, dim=0)
        loss_mask_list = torch.split(loss_mask, 1, dim=0)

        # Wrap each row of the batch in an SGLang request so a single extend
        # (prefill) pass can run over all of them.
        for idx, (curr_ids, curr_attn, curr_loss) in enumerate(
            zip(input_ids_list, attn_mask_list, loss_mask_list)
        ):
            req = Req(
                rid=str(idx),
                origin_input_text="",
                origin_input_ids=curr_ids.view(-1).tolist(),
                sampling_params=sampling_params,
            )
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            data_cache.append((curr_ids, curr_attn, curr_loss))
            reqs.append(req)

        hidden_states_list = self._extend(reqs)

        # Re-batch the per-request outputs.
        hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0)
        input_ids = torch.cat([d[0] for d in data_cache], dim=0)
        attention_mask = torch.cat([d[1] for d in data_cache], dim=0)
        loss_mask = torch.cat([d[2] for d in data_cache], dim=0)

        # Right-pad so downstream code sees uniformly shaped tensors.
        hidden_states = padding(hidden_states, left=False)
        input_ids = padding(input_ids, left=False)

        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )


class HFDFlashTargetModel(DFlashTargetModel):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "HFDFlashTargetModel":
        tp_size = get_tp_group().size()

        if tp_size > 1:
            # Shard the model across the tensor-parallel mesh via HF's
            # built-in tensor parallelism.
            device_kwargs = {
                "tp_plan": "auto",
                "tp_size": tp_size,
                "device_mesh": get_tp_device_mesh(),
            }
        else:
            device_kwargs = {
                "device_map": device,
            }

        target_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            output_hidden_states=True,
            trust_remote_code=trust_remote_code,
            attn_implementation="flash_attention_2",
            **device_kwargs,
            **kwargs,
        ).eval()

        return cls(target_model)

    @torch.no_grad()
    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> DFlashTargetOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # With output_hidden_states=True, outputs.hidden_states has
        # num_layers + 1 entries: index 0 is the embedding output and index
        # i + 1 is the output of decoder layer i, hence the offset below.
        offset = 1
        selected = []
        if self.capture_layer_ids is not None:
            for idx in self.capture_layer_ids:
                selected.append(outputs.hidden_states[idx + offset])
            hidden_states = torch.cat(selected, dim=-1)
        else:
            # Fall back to the final layer's hidden state.
            hidden_states = outputs.hidden_states[-1]

        return DFlashTargetOutput(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )
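

# Hypothetical convenience helper (not referenced elsewhere in this module):
# it documents the shape contract both backends appear to follow for
# DFlashTargetOutput.hidden_states under the concat-across-layers convention.
def _dflash_hidden_dim(
    hidden_size: int, capture_layer_ids: Optional[List[int]]
) -> int:
    # With k capture layers, k hidden states are concatenated on the last
    # dim; with no capture layers set, the last hidden state is used as-is.
    num_layers = len(capture_layer_ids) if capture_layer_ids else 1
    return hidden_size * num_layers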


def get_dflash_target_model(
    pretrained_model_name_or_path: str,
    backend: str = "sglang",
    torch_dtype: Optional[torch.dtype] = None,
    device: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> DFlashTargetModel:
    """Factory for DFlash target-model backends ("sglang" or "hf")."""
    if backend == "sglang":
        return SGLangDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    elif backend == "hf":
        return HFDFlashTargetModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
    else:
        raise ValueError(f"Invalid backend: {backend}")
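

if __name__ == "__main__":
    # Smoke-test sketch, not part of the library API. Loud assumptions:
    # specforge's distributed state is already initialized, a CUDA device
    # with flash-attn is available, and the model id below is a placeholder
    # that must be replaced with a real causal LM.
    demo_input_ids = torch.randint(0, 32000, (2, 16), device="cuda")
    demo_mask = torch.ones_like(demo_input_ids)

    target = get_dflash_target_model(
        "my-org/my-target-model",  # placeholder model id
        backend="hf",
        torch_dtype=torch.bfloat16,
        device="cuda",
    )
    target.set_capture_layers([0, 7, 15])
    out = target.generate_dflash_data(demo_input_ids, demo_mask, demo_mask.clone())
    # With three capture layers, the last dim is 3 * hidden_size.
    print(out.hidden_states.shape, out.input_ids.shape)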