| | """ |
| | AuriStream Parallel model for HuggingFace Transformers. |
| | """ |
| |
|
| | import math |
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutput |
| |
|
| | from .configuration_auristream_parallel import AuriStreamParallelConfig |
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the last (normalised) dimension.
        weight: If True, create a learnable per-feature gain initialised to ones;
            otherwise ``self.weight`` is ``None`` and no gain is applied.
        bias: Accepted for signature compatibility only; no bias is created.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (stabilised by ``eps``)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # The statistics are computed in float32 and the result is cast back to
        # the input dtype before the (optional) gain is applied.
        normalised = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normalised
        return normalised * self.weight
| |
|
| |
|
class Rotary(nn.Module):
    """Produce rotary position-embedding cos/sin tables for a sequence.

    Args:
        dim: Per-head feature size; one inverse frequency is created for each
            index in ``range(0, dim, 2)``.
        base: Frequency base (``rope_theta``).
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Registered as a (persistent) buffer, so it appears in the state dict.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """Return ``(cos, sin)``, each of shape ``(1, x.shape[1], 1, len(inv_freq))``.

        Only ``x.shape[1]`` (the sequence length) and ``x.device`` are used.

        Positions and angles are always computed in float32: if the module has
        been cast to bf16/fp16, integer positions above 256 (bf16) are not exactly
        representable and would otherwise be rounded before the rotation angle is
        formed. The tables are cast back to the buffer's dtype so the returned
        dtype is the same as before.
        """
        seq_len = x.shape[1]
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        out_dtype = self.inv_freq.dtype
        cos = freqs.cos().to(out_dtype)
        sin = freqs.sin().to(out_dtype)
        return cos[None, :, None, :], sin[None, :, None, :]
| |
|
| |
|
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x``.

    The last dimension is split into ``x1 = x[..., :d]`` and ``x2 = x[..., d:]``
    with ``d = x.shape[-1] // 2``, and the result is
    ``cat([x1 * cos + x2 * sin, -x1 * sin + x2 * cos], dim=-1)``.

    ``cos`` and ``sin`` must broadcast against each half, e.g. the
    ``(1, seq, 1, d)`` tables from ``Rotary`` for a 4-D ``(batch, seq, heads, 2d)``
    input. Inputs of any rank are accepted; only the last axis is rotated.
    An odd last dimension is not supported (the halves would differ in size).
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], dim=-1)
| |
|
| |
|
class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention with no causal mask and optional rotary embeddings.

    Every position attends to every other position of the input sequence
    (``is_causal=False``, no attention mask).
    """

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd={config.n_embd} must be divisible by n_head={config.n_head}"
            )
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Only its probability ``p`` is used: it is passed to SDPA as ``dropout_p``.
        self.attn_dropout = nn.Dropout(config.dropout)

        self.rotary = None
        if getattr(config, "use_rope", True):
            rope_theta = getattr(config, "rope_theta", 10000.0) or 10000.0
            self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, n_embd)``; returns the same shape."""
        bsz, tsz, channels = x.size()
        heads_shape = (bsz, tsz, self.n_head, self.head_dim)

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(heads_shape)
        k = k.view(heads_shape)
        v = v.view(heads_shape)

        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # Apply the configured attention dropout only in training mode, mirroring
        # nn.Dropout's train/eval behaviour; in eval this is exactly 0.
        dropout_p = self.attn_dropout.p if self.training else 0.0
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=dropout_p,
            is_causal=False,
        )

        y = y.transpose(1, 2).contiguous().view(bsz, tsz, channels)
        return self.c_proj(y)
| |
|
| |
|
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear(n_embd -> 4*n_embd) -> SiLU -> Linear -> Dropout."""

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=config.bias)
        self.act = nn.SiLU()
        self.c_proj = nn.Linear(hidden_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
| |
|
| |
|
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each with a residual add."""

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        # Construction order is kept (attn, mlp, norm1, norm2) so seeded
        # initialisation draws random numbers in the same sequence.
        self.attn = BidirectionalSelfAttention(config)
        self.mlp = MLP(config)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
| |
|
| |
|
class AuriStreamPreTrainedModel(PreTrainedModel):
    """Base class connecting AuriStream Parallel models to ``PreTrainedModel``."""

    config_class = AuriStreamParallelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """Initialise a single sub-module in place.

        - ``nn.Linear``: weight ~ N(0, 0.02), bias zeroed (if present).
        - ``nn.Embedding``: weight ~ N(0, 0.02).
        - ``RMSNorm``: gain set to ones (its construction-time value), so any
          code path that (re)initialises parameters through this hook does not
          leave the norm gains untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm) and module.weight is not None:
            torch.nn.init.ones_(module.weight)
| |
|
| |
|
class AuriStreamModel(AuriStreamPreTrainedModel):
    """HF-compatible AuriStream Parallel model.

    Every ``group_size`` consecutive tokens are embedded, concatenated and
    projected to a single ``n_embd`` vector, so the transformer stack runs over
    ``seq_len // group_size`` grouped positions. ``group_size`` output heads then
    emit logits for each token slot of each group, giving one logit vector per
    input token.
    """

    config_class = AuriStreamParallelConfig

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__(config)
        # ``PreTrainedModel.__init__`` already stores the config on ``self.config``;
        # re-assigning the raw argument here would discard any adjustment it made.

        self.group_size = int(getattr(config, "group_size", 4))
        grouped_seq_len = max(1, config.seq_len // self.group_size)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # Learned positions are used only when rotary embeddings are disabled;
        # they index grouped positions, hence ``seq_len // group_size`` rows.
        self.wpe = None
        if not getattr(config, "use_rope", True):
            self.wpe = nn.Embedding(grouped_seq_len, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)

        # Projects the concatenated embeddings of one group to one model vector.
        self.group_in_proj = nn.Linear(self.group_size * config.n_embd, config.n_embd, bias=False)
        # Head ``j`` produces logits for slot ``j`` of every group.
        self.parallel_heads = nn.ModuleList(
            [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(self.group_size)]
        )

        # Toggled by ``gradient_checkpointing_enable()``. The base class declares
        # ``supports_gradient_checkpointing = True``, and HF requires this flag to
        # exist on at least one module for that call to succeed.
        self.gradient_checkpointing = False

        # Standard HF hook: applies ``_init_weights`` to every module (same order
        # as ``self.apply``) and performs the remaining post-construction setup.
        self.post_init()
        # Down-scale the residual output projections by 1/sqrt(2 * n_layer).
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def _group_embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Embed ``(batch, seq)`` ids into ``(batch, seq // group_size, n_embd)`` vectors.

        Raises:
            ValueError: if ``seq`` is not a multiple of ``group_size``, or learned
                positions are used and the grouped length exceeds their table.
        """
        bsz, tsz = input_ids.shape
        if tsz % self.group_size != 0:
            raise ValueError(
                f"Sequence length {tsz} must be divisible by group_size={self.group_size}"
            )

        n_groups = tsz // self.group_size
        tok_emb = self.wte(input_ids)
        # Token ``g * group_size + j`` lands in group ``g``, slot ``j``.
        grouped = tok_emb.reshape(bsz, n_groups, self.group_size * self.config.n_embd)
        x = self.group_in_proj(grouped)

        if self.wpe is not None:
            if n_groups > self.wpe.num_embeddings:
                # Fail clearly instead of an IndexError / device-side assert.
                raise ValueError(
                    f"Grouped sequence length {n_groups} exceeds the "
                    f"{self.wpe.num_embeddings} learned position embeddings"
                )
            pos = torch.arange(n_groups, device=input_ids.device)
            x = x + self.wpe(pos)

        return self.drop(x)

    def _decode_parallel_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, groups, n_embd)`` to ``(batch, groups * group_size, vocab)`` logits.

        Output position ``g * group_size + j`` comes from head ``j`` applied to group
        ``g``, matching the token order used by ``_group_embed``.
        """
        per_head = [head(x) for head in self.parallel_heads]
        logits = torch.stack(per_head, dim=2)
        bsz, tg, gsz, vsz = logits.shape
        return logits.reshape(bsz, tg * gsz, vsz)

    def _run_block(self, block: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Apply one block, using activation checkpointing when enabled in training."""
        if not (self.gradient_checkpointing and self.training):
            return block(x)
        ckpt_fn = getattr(self, "_gradient_checkpointing_func", None)
        if ckpt_fn is not None:
            # Installed by HF's ``gradient_checkpointing_enable()``.
            return ckpt_fn(block.__call__, x)
        # Fallback when the flag was set directly rather than through HF's API.
        from torch.utils.checkpoint import checkpoint

        return checkpoint(block, x, use_reentrant=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """Run the model on a batch of token ids.

        Args:
            input_ids: ``(batch, seq_len)`` token ids. Trailing tokens beyond the
                largest multiple of ``group_size`` are dropped.
            labels: Optional targets compared position-by-position with the
                logits (no shifting is done here) after the same truncation;
                entries equal to ``config.ignore_index`` (default -100) are ignored.
            output_hidden_states: Also return the grouped hidden states (the
                embedding output followed by each block's output, before
                ``ln_f``). ``None`` falls back to ``config.output_hidden_states``.
            return_dict: Return a ``CausalLMOutput`` instead of a tuple. ``None``
                falls back to ``config.use_return_dict``.
            seq: Alias for ``input_ids``; takes precedence when given.
            tgt: Alias for ``labels``; takes precedence when given.

        Returns:
            ``CausalLMOutput``, or a tuple ``([loss,] logits[, hidden_states])``;
            ``logits`` has shape ``(batch, usable_len, vocab_size)``.

        Raises:
            ValueError: if no ids are given, the sequence is shorter than
                ``group_size``, or learned positions cannot cover the input.
        """
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("input_ids (or seq) must be provided")
        # HF convention: ``None`` means "use the config default" rather than False.
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
        if usable_len <= 0:
            raise ValueError(
                f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
            )
        if usable_len != input_ids.shape[1]:
            input_ids = input_ids[:, :usable_len]
            if labels is not None:
                labels = labels[:, :usable_len]

        x = self._group_embed(input_ids)

        all_hidden_states = (x,) if output_hidden_states else ()
        for block in self.h:
            x = self._run_block(block, x)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)

        x = self.ln_f(x)
        logits = self._decode_parallel_logits(x)

        loss = None
        if labels is not None:
            # Labels may live on another device when the model is sharded.
            labels = labels.to(logits.device)
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=getattr(self.config, "ignore_index", -100),
            )

        if not return_dict:
            out = (logits,)
            if output_hidden_states:
                out = out + (all_hidden_states,)
            return ((loss,) + out) if loss is not None else out

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=None,
        )
| |
|