""" AuriStream Parallel Configuration for HuggingFace Transformers. """ from transformers import PretrainedConfig class AuriStreamParallelConfig(PretrainedConfig): """Configuration class for AuriStream Parallel models.""" model_type = "AuriStreamParallel" def __init__( self, vocab_size: int = 8193, base_vocab_size: int = 8192, mask_token_id: int = 8192, ignore_index: int = -100, n_embd: int = 768, n_layer: int = 12, n_head: int = 12, dropout: float = 0.0, bias: bool = False, rope_theta: float = 10000.0, use_rope: bool = True, group_size: int = 4, seq_len: int = 4096, skip_connections: bool = False, mask_schedule: str = "linear_text_prime", **kwargs, ): self.vocab_size = vocab_size self.base_vocab_size = base_vocab_size self.mask_token_id = mask_token_id self.ignore_index = ignore_index self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head self.dropout = dropout self.bias = bias self.rope_theta = rope_theta self.use_rope = use_rope self.group_size = group_size self.seq_len = seq_len self.skip_connections = skip_connections self.mask_schedule = mask_schedule super().__init__(**kwargs) @classmethod def from_local_config(cls, local_cfg): """Create AuriStreamParallelConfig from local dataclass config.""" config_dict = {} known_attrs = [ "vocab_size", "base_vocab_size", "mask_token_id", "ignore_index", "n_embd", "n_layer", "n_head", "dropout", "bias", "rope_theta", "use_rope", "group_size", "seq_len", "skip_connections", "mask_schedule", ] for attr in known_attrs: if hasattr(local_cfg, attr): config_dict[attr] = getattr(local_cfg, attr) return cls(**config_dict)