| import json |
| import os |
| from typing import Optional, Union |
|
|
| import torch |
| from transformers import AutoConfig |
| from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase |
| from transformers import ( |
| GptOssConfig, |
| Llama4Config, |
| Llama4TextConfig, |
| LlamaConfig, |
| Phi3Config, |
| PretrainedConfig, |
| Qwen2Config, |
| Qwen3Config, |
| Qwen3MoeConfig, |
| modeling_utils, |
| ) |
|
|
| from .draft.llama3_eagle import LlamaForCausalLMEagle3 |
| from .target.custom_backend import ( |
| GptOssForCausalLM, |
| Llama4ForCausalLM, |
| LlamaForCausalLM, |
| Phi3ForCausalLM, |
| Qwen2ForCausalLM, |
| Qwen3ForCausalLM, |
| Qwen3MoeForCausalLM, |
| ) |
|
|
|
|
class AutoEagle3DraftModel(AutoModelForCausalLMBase):
    """Factory for Eagle3 draft models, dispatching on the config type."""

    # Maps a config class to the draft model class that consumes it.
    _model_mapping = {
        LlamaConfig: LlamaForCausalLMEagle3,
    }

    @classmethod
    def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs):
        """
        Instantiate a draft model for ``config`` based on the
        ``_model_mapping`` class variable.

        Args:
            config (PretrainedConfig): A configuration object; its exact type
                must be a key of ``_model_mapping``.
            torch_dtype: Optional ``torch.dtype`` to cast the model to after
                construction.
            **config_kwargs: Extra keyword arguments forwarded to the model
                constructor.

        Returns:
            A model instance.

        Raises:
            KeyError: If ``type(config)`` is not in ``_model_mapping``.
        """
        _model_cls = cls._model_mapping[type(config)]
        model = _model_cls(config, **config_kwargs)

        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        *model_args,
        **kwargs,
    ):
        """
        Load a pretrained draft model while suppressing the expected warning
        about ``embed_tokens.weight`` being newly initialized (draft models
        intentionally reinitialize the embedding table).
        """
        original_warn = modeling_utils.logger.warning

        # BUG FIX: logging's warning() is called as warning(msg, *args,
        # **kwargs) for lazy %-formatting; the previous single-argument
        # signature raised TypeError on any such call from transformers.
        def filtered_warning(msg, *args, **kwargs):
            if "embed_tokens.weight" in str(msg) and "initialized" in str(msg):
                return
            original_warn(msg, *args, **kwargs)

        # NOTE(review): monkey-patching a shared module logger is not
        # thread-safe; acceptable only if model loading is single-threaded.
        modeling_utils.logger.warning = filtered_warning

        try:
            model = super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        finally:
            # Always restore the original logger method, even when loading fails.
            modeling_utils.logger.warning = original_warn

        return model
|
|
|
|
class AutoDistributedTargetModel(AutoModelForCausalLMBase):
    """Factory for distributed target models, dispatching on the config type."""

    # Maps a config class to its supported backend model classes; the first
    # entry of each list is the default backend.
    _model_mapping = {
        Llama4TextConfig: [Llama4ForCausalLM],
        Qwen3MoeConfig: [Qwen3MoeForCausalLM],
        Qwen2Config: [Qwen2ForCausalLM],
        LlamaConfig: [LlamaForCausalLM],
        Qwen3Config: [Qwen3ForCausalLM],
        Phi3Config: [Phi3ForCausalLM],
        GptOssConfig: [GptOssForCausalLM],
    }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **config_kwargs,
    ):
        """
        Load a pretrained target model for the given checkpoint.

        Args:
            pretrained_model_name_or_path: Hub model id or local checkpoint path.
            torch_dtype: Optional dtype passed through to ``from_pretrained``.
            device: Optional device string; when ``None`` the model is moved
                to the default CUDA device.
            cache_dir: Optional cache directory for config and weights.
            **config_kwargs: Extra keyword arguments forwarded to the backend
                model's ``from_pretrained``.

        Returns:
            The loaded model, placed on ``device`` (or CUDA by default).

        Raises:
            ValueError: If the checkpoint's config type is unsupported.
        """
        # BUG FIX: forward cache_dir so the config is resolved from the same
        # cache as the weights below.
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
        )

        # Llama4 wraps its text config inside a multimodal config; unwrap it
        # so the mapping lookup below sees Llama4TextConfig.
        if isinstance(config, Llama4Config):
            config = config.text_config

        # BUG FIX: was an assert, which is stripped under ``python -O``;
        # raise explicitly for reliable input validation.
        if type(config) not in cls._model_mapping:
            raise ValueError(f"Unsupported config type: {type(config)}")
        model_cls = cls._model_mapping[type(config)][0]
        model = model_cls.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            **config_kwargs,
        )

        if device is not None:
            model = model.to(device)
        else:
            # Default placement: the current CUDA device.
            model = model.cuda()
        return model
|
|
|
|
class AutoDraftModelConfig:
    """Factory for draft model configuration objects loaded from JSON files."""

    # Maps an architecture name (from the config's "architectures" list) to
    # the transformers config class used to build it.
    _config_mapping = {
        "LlamaForCausalLMEagle3": LlamaConfig,
    }

    @classmethod
    def from_file(cls, config_path: str):
        """
        Build a configuration object from a JSON config file based on the
        ``_config_mapping`` class variable.

        Args:
            config_path (str): Path to a JSON configuration file.

        Returns:
            A configuration object.

        Raises:
            ValueError: If the file lists no architectures, more than one
                architecture, or an unsupported architecture.
        """
        # BUG FIX: read the JSON explicitly as UTF-8 instead of relying on
        # the locale's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Draft models never tie input/output embeddings; force the flag off.
        if "tie_word_embeddings" in config:
            print("Set draft model tie_word_embeddings to False")
            config["tie_word_embeddings"] = False

        architectures = config.get("architectures", None)

        if architectures is None:
            raise ValueError("No architectures found in the config file")

        if len(architectures) != 1:
            raise ValueError("Only one architecture is supported")

        architecture = architectures[0]

        if architecture not in cls._config_mapping:
            raise ValueError(f"Architecture {architecture} not supported")

        # Default the draft vocab size to the full vocab size when absent.
        if "draft_vocab_size" not in config or config["draft_vocab_size"] is None:
            config["draft_vocab_size"] = config.get("vocab_size", None)

        return cls._config_mapping[architecture].from_dict(config)
|
|