"""
AuriStream Parallel Configuration for HuggingFace Transformers.
"""

from transformers import PretrainedConfig
| |
|
| |
|
| | class AuriStreamParallelConfig(PretrainedConfig): |
| | """Configuration class for AuriStream Parallel models.""" |
| |
|
| | model_type = "AuriStreamParallel" |
| |
|
| | def __init__( |
| | self, |
| | vocab_size: int = 8193, |
| | base_vocab_size: int = 8192, |
| | mask_token_id: int = 8192, |
| | ignore_index: int = -100, |
| | n_embd: int = 768, |
| | n_layer: int = 12, |
| | n_head: int = 12, |
| | dropout: float = 0.0, |
| | bias: bool = False, |
| | rope_theta: float = 10000.0, |
| | use_rope: bool = True, |
| | group_size: int = 4, |
| | seq_len: int = 4096, |
| | skip_connections: bool = False, |
| | mask_schedule: str = "linear_text_prime", |
| | **kwargs, |
| | ): |
| | self.vocab_size = vocab_size |
| | self.base_vocab_size = base_vocab_size |
| | self.mask_token_id = mask_token_id |
| | self.ignore_index = ignore_index |
| | self.n_embd = n_embd |
| | self.n_layer = n_layer |
| | self.n_head = n_head |
| | self.dropout = dropout |
| | self.bias = bias |
| | self.rope_theta = rope_theta |
| | self.use_rope = use_rope |
| | self.group_size = group_size |
| | self.seq_len = seq_len |
| | self.skip_connections = skip_connections |
| | self.mask_schedule = mask_schedule |
| |
|
| | super().__init__(**kwargs) |
| |
|
| | @classmethod |
| | def from_local_config(cls, local_cfg): |
| | """Create AuriStreamParallelConfig from local dataclass config.""" |
| | config_dict = {} |
| | known_attrs = [ |
| | "vocab_size", "base_vocab_size", "mask_token_id", "ignore_index", |
| | "n_embd", "n_layer", "n_head", "dropout", "bias", "rope_theta", |
| | "use_rope", "group_size", "seq_len", "skip_connections", "mask_schedule", |
| | ] |
| | for attr in known_attrs: |
| | if hasattr(local_cfg, attr): |
| | config_dict[attr] = getattr(local_cfg, attr) |
| | return cls(**config_dict) |
| |
|