# Source: AuriStreamParallel-base / configuration_auristream_parallel.py
# Uploaded by klemenk — commit c07a579 ("Upload AuriStream Parallel base model code").
"""
AuriStream Parallel Configuration for HuggingFace Transformers.
"""
from transformers import PretrainedConfig
class AuriStreamParallelConfig(PretrainedConfig):
    """Configuration class for AuriStream Parallel models.

    Stores the model hyperparameters given to ``__init__`` as plain instance
    attributes. Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``, which handles its own standard options.
    """

    model_type = "AuriStreamParallel"

    # Hyperparameters copied by ``from_local_config``. This must mirror the
    # named parameters of ``__init__`` below: a name missing here is silently
    # left at its default when converting a local config. Subclasses that add
    # constructor parameters can extend this tuple.
    _LOCAL_CONFIG_FIELDS = (
        "vocab_size",
        "base_vocab_size",
        "mask_token_id",
        "ignore_index",
        "n_embd",
        "n_layer",
        "n_head",
        "dropout",
        "bias",
        "rope_theta",
        "use_rope",
        "group_size",
        "seq_len",
        "skip_connections",
        "mask_schedule",
    )

    def __init__(
        self,
        vocab_size: int = 8193,
        base_vocab_size: int = 8192,
        mask_token_id: int = 8192,
        ignore_index: int = -100,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        use_rope: bool = True,
        group_size: int = 4,
        seq_len: int = 4096,
        skip_connections: bool = False,
        mask_schedule: str = "linear_text_prime",
        **kwargs,
    ) -> None:
        """Initialize the configuration.

        Args:
            vocab_size: Total vocabulary size. With the defaults this is
                ``base_vocab_size + 1``, presumably reserving one extra id
                for the mask token — confirm against the model code.
            base_vocab_size: Size of the base token vocabulary, presumably
                excluding special tokens.
            mask_token_id: Id of the mask token. The default (8192) is the
                first id past the default ``base_vocab_size``.
            ignore_index: Target value presumably ignored by the training
                loss. -100 matches PyTorch's ``CrossEntropyLoss`` default;
                confirm.
            n_embd: Presumably the hidden / embedding width.
            n_layer: Presumably the number of transformer layers.
            n_head: Presumably the number of attention heads.
            dropout: Presumably the dropout probability.
            bias: Presumably whether layers use bias terms.
            rope_theta: Presumably the base for rotary position embeddings.
            use_rope: Presumably whether rotary position embeddings are used.
            group_size: Stored as-is; its meaning is defined by the model
                code (not visible here).
            seq_len: Presumably the maximum sequence length.
            skip_connections: Stored as-is; its meaning is defined by the
                model code (not visible here).
            mask_schedule: Name of the masking schedule, interpreted
                elsewhere (not visible here).
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Model-specific fields are set before the base initializer runs.
        # Leftover kwargs are then consumed or stored by PretrainedConfig.
        self.vocab_size = vocab_size
        self.base_vocab_size = base_vocab_size
        self.mask_token_id = mask_token_id
        self.ignore_index = ignore_index
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout
        self.bias = bias
        self.rope_theta = rope_theta
        self.use_rope = use_rope
        self.group_size = group_size
        self.seq_len = seq_len
        self.skip_connections = skip_connections
        self.mask_schedule = mask_schedule
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg, **kwargs) -> "AuriStreamParallelConfig":
        """Create a config from a local (e.g. dataclass) config object.

        Each name in ``_LOCAL_CONFIG_FIELDS`` that exists as an attribute on
        ``local_cfg`` is copied. Names it lacks keep their ``__init__``
        defaults, and any other attributes of ``local_cfg`` are ignored.

        Args:
            local_cfg: Any object exposing the hyperparameters as attributes.
            **kwargs: Extra keyword arguments for the constructor. They take
                precedence over values read from ``local_cfg``. They can also
                carry options consumed by ``PretrainedConfig``.

        Returns:
            A new instance of ``cls``.
        """
        config_dict = {
            name: getattr(local_cfg, name)
            for name in cls._LOCAL_CONFIG_FIELDS
            if hasattr(local_cfg, name)
        }
        # Explicit overrides win over whatever the local config provides.
        config_dict.update(kwargs)
        return cls(**config_dict)