Lekr0's picture
Add files using upload-large-folder tool
62dca4c verified
import json
import os
from typing import Optional, Union
import torch
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase
from transformers import (
GptOssConfig,
Llama4Config,
Llama4TextConfig,
LlamaConfig,
Phi3Config,
PretrainedConfig,
Qwen2Config,
Qwen3Config,
Qwen3MoeConfig,
modeling_utils,
)
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.custom_backend import (
GptOssForCausalLM,
Llama4ForCausalLM,
LlamaForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Qwen3MoeForCausalLM,
)
class AutoEagle3DraftModel(AutoModelForCausalLMBase):
    """Auto-factory that builds a draft model from a configuration object.

    The concrete model class is looked up in ``_model_mapping`` by the exact
    type of the configuration.
    """

    # The model mapping is currently hardcoded; we should support lazy model
    # mapping via a registry.
    _model_mapping = {
        LlamaConfig: LlamaForCausalLMEagle3,
    }

    @classmethod
    def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs):
        """Instantiate the draft model registered for ``type(config)``.

        Args:
            config (PretrainedConfig): A configuration object whose exact type
                is a key of ``_model_mapping``.
            torch_dtype: Optional dtype; when given, the freshly built model is
                converted with ``model.to(dtype=torch_dtype)``.
            **config_kwargs: Extra keyword arguments forwarded to the model
                class constructor.

        Returns:
            A model instance.

        Raises:
            KeyError: If ``type(config)`` has no entry in ``_model_mapping``.
        """
        # Exact-type lookup: subclasses of a registered config are not matched.
        _model_cls = cls._model_mapping[type(config)]
        model = _model_cls(config, **config_kwargs)

        # Convert model to the specified dtype if provided.
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        *model_args,
        **kwargs,
    ):
        """Load a pretrained draft model while muting one specific warning.

        While the parent ``from_pretrained`` runs, ``modeling_utils.logger.warning``
        is temporarily replaced by a filter that drops any message containing
        both ``"embed_tokens.weight"`` and ``"initialized"``. All other
        warnings are passed through unchanged. The original logger method is
        always restored, even if loading raises.

        Args:
            pretrained_model_name_or_path: Model identifier or local path,
                forwarded to the parent ``from_pretrained``.
            *model_args: Positional arguments forwarded to the parent.
            **kwargs: Keyword arguments forwarded to the parent.

        Returns:
            The loaded model instance.
        """
        original_warn = modeling_utils.logger.warning

        def filtered_warning(msg, *log_args, **log_kwargs):
            # Accept and forward the full logging call signature
            # (format args, exc_info, stacklevel, ...) so that unrelated
            # warnings emitted during loading do not fail with a TypeError.
            text = str(msg)
            if "embed_tokens.weight" in text and "initialized" in text:
                # NOTE(review): presumably the draft checkpoint does not ship
                # the embedding weights, making this warning expected noise.
                return
            original_warn(msg, *log_args, **log_kwargs)

        modeling_utils.logger.warning = filtered_warning
        try:
            model = super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        finally:
            # Always undo the monkeypatch so later callers see the real logger.
            modeling_utils.logger.warning = original_warn
        return model
class AutoDistributedTargetModel(AutoModelForCausalLMBase):
    """Auto-factory that loads a target model using the custom backend classes.

    ``_model_mapping`` maps a configuration class to a list of candidate model
    classes; the first entry of the list is the one that gets loaded.
    """

    # The model mapping is currently hardcoded; we should support lazy model
    # mapping via a registry.
    _model_mapping = {
        Llama4TextConfig: [Llama4ForCausalLM],
        Qwen3MoeConfig: [Qwen3MoeForCausalLM],
        Qwen2Config: [Qwen2ForCausalLM],
        LlamaConfig: [LlamaForCausalLM],
        Qwen3Config: [Qwen3ForCausalLM],
        Phi3Config: [Phi3ForCausalLM],
        GptOssConfig: [GptOssForCausalLM],
    }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **config_kwargs,
    ):
        """Load the target model matching the checkpoint's configuration.

        Args:
            pretrained_model_name_or_path: Model identifier or local path.
            torch_dtype: Dtype forwarded to the model's ``from_pretrained``.
            device: Device to move the model to. When ``None`` the model is
                moved with ``model.cuda()``.
            cache_dir: Cache directory used for both the configuration lookup
                and the model weights.
            **config_kwargs: Extra keyword arguments forwarded to the model
                class's ``from_pretrained``.

        Returns:
            The loaded model, placed on ``device`` (or on CUDA by default).

        Raises:
            AssertionError: If the configuration type is not in
                ``_model_mapping``.
        """
        # Resolve the config from the same cache location as the weights so a
        # custom ``cache_dir`` is honoured consistently.
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
        )

        # Llama 4 checkpoints expose a composite config; only the text part is
        # registered in the mapping.
        if isinstance(config, Llama4Config):
            config = config.text_config

        # NOTE: this assert is stripped under ``python -O``; in that case an
        # unsupported config surfaces as a KeyError on the lookup below.
        assert (
            type(config) in cls._model_mapping
        ), f"Unsupported config type: {type(config)}"

        model_cls = cls._model_mapping[type(config)][0]
        model = model_cls.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            cache_dir=cache_dir,
            **config_kwargs,
        )

        if device is not None:
            model = model.to(device)
        else:
            model = model.cuda()
        return model
class AutoDraftModelConfig:
    """Builds a draft-model configuration object from a JSON config file."""

    # Maps the single entry of the file's "architectures" list to the
    # configuration class used to parse the file.
    _config_mapping = {
        "LlamaForCausalLMEagle3": LlamaConfig,
    }

    @classmethod
    def from_file(cls, config_path: str):
        """Create a configuration object from a JSON configuration file.

        The file content is adjusted before being handed to the configuration
        class selected through ``_config_mapping``:

        * if ``tie_word_embeddings`` is present, it is forced to ``False``;
        * if ``draft_vocab_size`` is missing or ``None``, it is set to the
          value of ``vocab_size`` (or ``None`` when that is missing too).

        Args:
            config_path (str): A path to a configuration file.

        Returns:
            A configuration object.

        Raises:
            ValueError: If the file has no ``architectures`` entry, lists a
                number of architectures other than one, or names an
                architecture that is not in ``_config_mapping``.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding when decoding the file.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        if "tie_word_embeddings" in config:
            print("Set draft model tie_word_embeddings to False")
            config["tie_word_embeddings"] = False

        # Check for architectures.
        architectures = config.get("architectures", None)
        if architectures is None:
            raise ValueError("No architectures found in the config file")
        if len(architectures) != 1:
            raise ValueError("Only one architecture is supported")

        architecture = architectures[0]
        if architecture not in cls._config_mapping:
            raise ValueError(f"Architecture {architecture} not supported")

        # If draft_vocab_size is not in config or is None, fall back to
        # vocab_size.
        if config.get("draft_vocab_size") is None:
            config["draft_vocab_size"] = config.get("vocab_size", None)

        return cls._config_mapping[architecture].from_dict(config)