"""AuriStream Parallel model for HuggingFace Transformers."""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_auristream_parallel import AuriStreamParallelConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the last (normalised) dimension.
        weight: When true, a learnable per-channel scale initialised to ones
            is multiplied onto the normalised output.
        bias: Accepted so call sites can pass ``config.bias``; no bias
            parameter is created.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the scale."""
        out = self._norm(x.float()).type_as(x)
        return out * self.weight if self.weight is not None else out


class Rotary(nn.Module):
    """Produces cos/sin tables for rotary position embeddings.

    Args:
        dim: Per-head feature size; one frequency is created per even index
            in ``range(0, dim, 2)``.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """Return ``(cos, sin)``, each shaped ``(1, seq_len, 1, n_freq)``.

        ``seq_len`` is read from ``x.shape[1]``. The angles are always
        computed in float32: if the module has been cast to half precision
        the ``inv_freq`` buffer is half precision too, and casting the
        position index to that dtype (as was done previously) cannot
        represent large positions exactly, so distinct positions collide.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        return freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]


def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of a 4-D tensor.

    Args:
        x: Tensor whose 4th dimension is split into two equal halves.
        cos: Cosine table broadcastable against each half of ``x``.
        sin: Sine table broadcastable against each half of ``x``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    # cos/sin may be float32 while x is half precision; type promotion would
    # otherwise hand back float32 q/k that no longer match v's dtype.
    return torch.cat([y1, y2], dim=3).type_as(x)


class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask.

    Reads ``n_head``, ``n_embd`` and ``dropout`` from the config, plus the
    optional ``use_rope`` (default ``True``) and ``rope_theta`` (default
    ``10000.0``) attributes.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Validate before deriving head_dim; an explicit raise also survives
        # ``python -O``, unlike an assert.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd={self.n_embd} must be divisible by n_head={self.n_head}"
            )
        self.head_dim = self.n_embd // self.n_head

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # NOTE(review): this dropout module is created but never applied in
        # forward(); kept as-is so behaviour is unchanged -- confirm whether
        # attention dropout was intended.
        self.attn_dropout = nn.Dropout(config.dropout)

        self.rotary = None
        if getattr(config, "use_rope", True):
            rope_theta = getattr(config, "rope_theta", 10000.0) or 10000.0
            self.rotary = Rotary(self.head_dim, base=rope_theta)

    def forward(self, x):
        """Attend over all positions of ``x`` (batch, time, channels)."""
        bsz, tsz, channels = x.size()

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        q = q.view(bsz, tsz, self.n_head, self.head_dim)
        k = k.view(bsz, tsz, self.n_head, self.head_dim)
        v = v.view(bsz, tsz, self.n_head, self.head_dim)

        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)

        # SDPA expects (batch, heads, time, head_dim).
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=False,
        )
        y = y.transpose(1, 2).contiguous().view(bsz, tsz, channels)
        return self.c_proj(y)


class MLP(nn.Module):
    """Feed-forward block: Linear (4x expansion) -> SiLU -> Linear -> Dropout."""

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.act = nn.SiLU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        return self.dropout(x)


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.attn = BidirectionalSelfAttention(config)
        self.mlp = MLP(config)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class AuriStreamPreTrainedModel(PreTrainedModel):
    """Base class binding the AuriStream config and weight initialisation."""

    config_class = AuriStreamParallelConfig
    base_model_prefix = "model"
    # NOTE(review): forward() in this file contains no checkpointing path;
    # confirm this flag is actually honoured before relying on it.
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


class AuriStreamModel(AuriStreamPreTrainedModel):
    """HF-compatible AuriStream Parallel model.

    Consecutive groups of ``group_size`` token embeddings are concatenated
    and projected to one vector per group, passed through the bidirectional
    transformer blocks, and decoded by ``group_size`` separate linear heads
    (one per position inside a group).
    """

    config_class = AuriStreamParallelConfig

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__(config)
        self.config = config
        self.group_size = int(getattr(config, "group_size", 4))
        grouped_seq_len = max(1, config.seq_len // self.group_size)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # A learned position table is only created when rotary is disabled.
        self.wpe = None
        if not getattr(config, "use_rope", True):
            self.wpe = nn.Embedding(grouped_seq_len, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)

        self.group_in_proj = nn.Linear(
            self.group_size * config.n_embd, config.n_embd, bias=False
        )
        self.parallel_heads = nn.ModuleList(
            [
                nn.Linear(config.n_embd, config.vocab_size, bias=False)
                for _ in range(self.group_size)
            ]
        )

        self.apply(self._init_weights)
        # Every projection named ``c_proj`` is re-initialised with a std
        # scaled down by sqrt(2 * n_layer).
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def _group_embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Embed tokens and merge each run of ``group_size`` into one vector.

        Args:
            input_ids: 2-D tensor ``(batch, time)`` of token ids.

        Returns:
            Tensor of shape ``(batch, time // group_size, n_embd)`` after
            the group projection, optional learned positions, and dropout.

        Raises:
            ValueError: If ``time`` is not a multiple of ``group_size``.
        """
        bsz, tsz = input_ids.shape
        if tsz % self.group_size != 0:
            raise ValueError(
                f"Sequence length {tsz} must be divisible by group_size={self.group_size}"
            )

        n_groups = tsz // self.group_size
        tok_emb = self.wte(input_ids)
        grouped = tok_emb.view(bsz, n_groups, self.group_size, self.config.n_embd)
        grouped = grouped.reshape(bsz, n_groups, self.group_size * self.config.n_embd)
        x = self.group_in_proj(grouped)

        if self.wpe is not None:
            pos = torch.arange(x.size(1), device=input_ids.device)
            x = x + self.wpe(pos)
        return self.drop(x)

    def _decode_parallel_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every parallel head and interleave the results along time.

        Returns:
            Logits reshaped from ``(B, T_g, G, V)`` to ``(B, T_g * G, V)``.
        """
        per_head = [head(x) for head in self.parallel_heads]
        logits = torch.stack(per_head, dim=2)  # (B, T_g, G, V)
        bsz, tg, gsz, vsz = logits.shape
        return logits.reshape(bsz, tg * gsz, vsz)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: Token ids ``(batch, time)``.
            labels: Targets compared position-for-position with the logits
                (no shifting is applied here).
            output_hidden_states: When true, also return the grouped
                embedding and the output of every block (taken before the
                final norm).
            return_dict: When true return a ``CausalLMOutput``, otherwise a
                tuple. ``None`` falls back to ``config.use_return_dict``
                (``True`` if the config does not define it).
            seq: Alias that overrides ``input_ids`` when given.
            tgt: Alias that overrides ``labels`` when given.

        Returns:
            ``CausalLMOutput`` or a tuple ``([loss,] logits[, hidden_states])``.

        Raises:
            ValueError: If no input is supplied or it is shorter than
                ``group_size``.
        """
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("input_ids (or seq) must be provided")

        # HF convention: ``None`` means "use the config default". Previously
        # ``None`` was treated as falsy and silently produced a tuple.
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        # Drop trailing tokens that do not fill a whole group.
        usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
        if usable_len <= 0:
            raise ValueError(
                f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
            )
        if usable_len != input_ids.shape[1]:
            input_ids = input_ids[:, :usable_len]
            if labels is not None:
                labels = labels[:, :usable_len]

        x = self._group_embed(input_ids)

        # Collect in a list to avoid rebuilding a tuple on every layer.
        collected = [x] if output_hidden_states else None
        for block in self.h:
            x = block(x)
            if collected is not None:
                collected.append(x)
        all_hidden_states = tuple(collected) if collected is not None else None

        x = self.ln_f(x)
        logits = self._decode_parallel_logits(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=getattr(self.config, "ignore_index", -100),
            )

        if not return_dict:
            out = (logits,)
            if output_hidden_states:
                out = out + (all_hidden_states,)
            return ((loss,) + out) if loss is not None else out

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=None,
        )