import json
import os
from typing import Optional, Union

import torch
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase
from transformers import (
GptOssConfig,
Llama4Config,
Llama4TextConfig,
LlamaConfig,
Phi3Config,
PretrainedConfig,
Qwen2Config,
Qwen3Config,
Qwen3MoeConfig,
modeling_utils,
)

from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.custom_backend import (
GptOssForCausalLM,
Llama4ForCausalLM,
LlamaForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Qwen3MoeForCausalLM,
)


class AutoEagle3DraftModel(AutoModelForCausalLMBase):
    """Factory that builds an EAGLE3 draft model from its configuration class."""

    # The model mapping is currently hardcoded; we should support lazy model
    # mapping via a registry instead.
    _model_mapping = {
        LlamaConfig: LlamaForCausalLMEagle3,
    }
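    # A sketch of how a new architecture could be registered until a proper
    # registry exists (Qwen3ForCausalLMEagle3 is hypothetical and does not
    # exist in this codebase):
    #
    #   AutoEagle3DraftModel._model_mapping[Qwen3Config] = Qwen3ForCausalLMEagle3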

    @classmethod
    def from_config(
        cls,
        config: PretrainedConfig,
        torch_dtype: Optional[torch.dtype] = None,
        **config_kwargs,
    ):
        """
        Create a draft model instance for the given configuration, dispatching
        on the _model_mapping class variable.

        Args:
            config (PretrainedConfig): A configuration object.
            torch_dtype (torch.dtype, optional): If given, the model is cast
                to this dtype after construction.
            **config_kwargs: Extra keyword arguments forwarded to the model
                constructor.

        Returns:
            A model instance.
        """
        # Look up the concrete model class for this config type in the mapping.
        _model_cls = cls._model_mapping[type(config)]
        model = _model_cls(config, **config_kwargs)

        # Convert the model to the specified dtype if provided.
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model
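
    # A minimal usage sketch for from_config; the config path below is
    # hypothetical and bfloat16 is just an example dtype:
    #
    #   draft_config = AutoDraftModelConfig.from_file("path/to/draft/config.json")
    #   draft_model = AutoEagle3DraftModel.from_config(
    #       draft_config, torch_dtype=torch.bfloat16
    #   )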

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        *model_args,
        **kwargs,
    ):
        # Temporarily swap out the transformers warning logger so that the
        # "newly initialized" warning for embed_tokens.weight is suppressed;
        # the draft model's embeddings are expected to be populated separately
        # rather than loaded from this checkpoint.
        original_warn = modeling_utils.logger.warning

        def filtered_warning(msg):
            if "embed_tokens.weight" in str(msg) and "initialized" in str(msg):
                return
            original_warn(msg)

        modeling_utils.logger.warning = filtered_warning
        try:
            model = super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        finally:
            # Always restore the original warning handler, even on failure.
            modeling_utils.logger.warning = original_warn
        return model
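

# A usage sketch for loading a pretrained EAGLE3 draft checkpoint; the
# checkpoint path is hypothetical:
#
#   draft = AutoEagle3DraftModel.from_pretrained(
#       "path/to/eagle3-draft-checkpoint", torch_dtype=torch.bfloat16
#   )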


class AutoDistributedTargetModel(AutoModelForCausalLMBase):
    """Factory that loads a target model using the custom distributed backend."""

    # The model mapping is currently hardcoded; we should support lazy model
    # mapping via a registry instead.
    _model_mapping = {
        Llama4TextConfig: [Llama4ForCausalLM],
        Qwen3MoeConfig: [Qwen3MoeForCausalLM],
        Qwen2Config: [Qwen2ForCausalLM],
        LlamaConfig: [LlamaForCausalLM],
        Qwen3Config: [Qwen3ForCausalLM],
        Phi3Config: [Phi3ForCausalLM],
        GptOssConfig: [GptOssForCausalLM],
    }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **config_kwargs,
    ):
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
        )

        # Llama4Config wraps a multimodal model; only the text sub-config is
        # relevant for the causal LM here.
        if isinstance(config, Llama4Config):
            config = config.text_config

        assert (
            type(config) in cls._model_mapping
        ), f"Unsupported config type: {type(config)}"
        model_cls = cls._model_mapping[type(config)][0]

        model = model_cls.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            **config_kwargs,
        )

        # Place the model on the requested device, defaulting to CUDA.
        if device is not None:
            model = model.to(device)
        else:
            model = model.cuda()
        return model
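

# A usage sketch for loading a target model; the model path and device are
# hypothetical:
#
#   target = AutoDistributedTargetModel.from_pretrained(
#       "path/to/target-model",
#       torch_dtype=torch.bfloat16,
#       device="cuda:0",
#   )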


class AutoDraftModelConfig:
    """Factory that builds a draft-model configuration object from a JSON file."""

    _config_mapping = {
        "LlamaForCausalLMEagle3": LlamaConfig,
    }

    @classmethod
    def from_file(cls, config_path: str):
        """
        Create a configuration object from a configuration file path,
        dispatching on the _config_mapping class variable.

        Args:
            config_path (str): A path to a configuration file.

        Returns:
            A configuration object.
        """
        with open(config_path, "r") as f:
            config = json.load(f)

        # Always untie word embeddings for the draft model.
        if "tie_word_embeddings" in config:
            print("Set draft model tie_word_embeddings to False")
            config["tie_word_embeddings"] = False

        # Check the architectures field; exactly one supported architecture is
        # expected.
        architectures = config.get("architectures", None)
        if architectures is None:
            raise ValueError("No architectures found in the config file")
        if len(architectures) != 1:
            raise ValueError("Only one architecture is supported")

        architecture = architectures[0]
        if architecture not in cls._config_mapping:
            raise ValueError(f"Architecture {architecture} not supported")

        # If draft_vocab_size is missing or None, fall back to vocab_size.
        if "draft_vocab_size" not in config or config["draft_vocab_size"] is None:
            config["draft_vocab_size"] = config.get("vocab_size", None)
        return cls._config_mapping[architecture].from_dict(config)
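

# An end-to-end sketch tying the three factories together; all paths are
# hypothetical:
#
#   draft_config = AutoDraftModelConfig.from_file("path/to/draft/config.json")
#   draft_model = AutoEagle3DraftModel.from_config(
#       draft_config, torch_dtype=torch.bfloat16
#   )
#   target_model = AutoDistributedTargetModel.from_pretrained(
#       "path/to/target-model", torch_dtype=torch.bfloat16
#   )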