# AuriStreamParallel-base / modeling_auristream_parallel.py
# klemenk's picture
# Upload AuriStream Parallel base model code
# 6211130 verified
"""
AuriStream Parallel model for HuggingFace Transformers.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_auristream_parallel import AuriStreamParallelConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed in float32,
    cast back to the input's dtype, and optionally multiplied by a learned
    per-feature scale.

    Args:
        dim: Size of the last dimension; length of the learned scale vector.
        weight: When true, a learnable scale initialised to ones is created;
            otherwise ``self.weight`` is ``None`` and no scaling is applied.
        bias: Accepted for call-site compatibility but never read here.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise ``x`` and apply the learned scale if one exists."""
        # The statistics are computed on a float32 copy; the result is then
        # cast back to the dtype of the incoming tensor before scaling.
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight
class Rotary(nn.Module):
    """Produces cosine and sine tables for rotary position embeddings.

    Args:
        dim: Feature size being rotated; ``dim // 2`` frequencies are stored
            (one per even index in ``range(0, dim, 2)``).
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

    def forward(self, x):
        """Return ``(cos, sin)`` tables for the positions along ``x.shape[1]``.

        Both tables have shape ``(1, x.shape[1], 1, len(inv_freq))`` so they
        broadcast over the first and third axes of a 4-D tensor.
        """
        n_positions = x.shape[1]
        # Positions adopt the tensor type of the stored frequencies before
        # the outer product; the table is then placed on x's device.
        positions = torch.arange(n_positions, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        cos_table = angles.cos()
        sin_table = angles.sin()
        return cos_table[None, :, None, :], sin_table[None, :, None, :]
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of ``x``'s last axis by the given angle tables.

    The last axis (axis 3) is split into a first half ``x1`` and a second
    half ``x2``; the result is ``[x1*cos + x2*sin, -x1*sin + x2*cos]``
    concatenated back along axis 3.

    Args:
        x: 4-D tensor whose last axis is split in half.
        cos: Cosine table broadcastable against one half of ``x``.
        sin: Sine table broadcastable against one half of ``x``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    half = x.shape[3] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    # Multiplying by higher-precision tables promotes the result (for
    # example half-precision x with float32 cos/sin yields float32). Cast
    # back so rotated tensors keep the dtype of the un-rotated ones they are
    # later combined with.
    return torch.cat([y1, y2], dim=3).type_as(x)
class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask.

    A single bias-free linear layer produces queries, keys and values; when
    rotary embeddings are enabled they are applied to queries and keys before
    ``F.scaled_dot_product_attention`` is called with ``is_causal=False``.

    Config fields read: ``n_head``, ``n_embd``, ``dropout`` and, optionally,
    ``use_rope`` (defaults to true) and ``rope_theta`` (defaults to 10000.0,
    also used when the stored value is falsy).
    """

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # NOTE(review): this dropout module is constructed but forward()
        # never calls it and passes no dropout to the attention kernel.
        self.attn_dropout = nn.Dropout(config.dropout)

        use_rope = getattr(config, "use_rope", True)
        if use_rope:
            rope_theta = getattr(config, "rope_theta", 10000.0) or 10000.0
            self.rotary = Rotary(self.head_dim, base=rope_theta)
        else:
            self.rotary = None

    def forward(self, x):
        """Attend over the second axis of a ``(batch, time, n_embd)`` input."""
        batch, steps, width = x.size()
        per_head_shape = (batch, steps, self.n_head, self.head_dim)

        # One projection yields q, k and v side by side on the last axis.
        query, key, value = (
            part.view(per_head_shape)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        if self.rotary is not None:
            cos, sin = self.rotary(query)
            query = apply_rotary_emb(query, cos, sin)
            key = apply_rotary_emb(key, cos, sin)

        # The kernel expects heads before time, hence the transposes.
        attended = F.scaled_dot_product_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            is_causal=False,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, SiLU, project, dropout.

    Config fields read: ``n_embd`` (input/output width), ``bias`` (whether
    the two linear layers carry a bias) and ``dropout``.
    """

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.act = nn.SiLU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply expand -> activation -> project -> dropout to ``x``."""
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
class Block(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by an MLP.

    Each sub-layer sees an RMS-normalised copy of the running activations and
    its output is added back onto them.
    """

    def __init__(self, config: AuriStreamParallelConfig):
        super().__init__()
        self.attn = BidirectionalSelfAttention(config)
        self.mlp = MLP(config)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class AuriStreamPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base carrying class-level settings and the
    default weight initialiser for AuriStream Parallel models."""

    config_class = AuriStreamParallelConfig
    base_model_prefix = "model"
    # NOTE(review): no forward pass in this file consults a
    # gradient-checkpointing flag -- confirm this claim is actually honoured.
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        Embedding and linear weights are drawn from N(0, 0.02**2); linear
        biases, when present, are zeroed. Every other module type is left
        untouched.
        """
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
class AuriStreamModel(AuriStreamPreTrainedModel):
    """HF-compatible AuriStream Parallel model.

    Token ids are embedded, each run of ``group_size`` consecutive embeddings
    is concatenated and projected to a single ``n_embd`` vector, the shorter
    grouped sequence is passed through the ``Block`` stack and a final
    RMSNorm, and ``group_size`` separate linear heads map every grouped
    position back to one logit vector per original token position.
    """

    config_class = AuriStreamParallelConfig

    def __init__(self, config: AuriStreamParallelConfig):
        """Build the model from ``config``.

        Config fields read here: ``seq_len``, ``vocab_size``, ``n_embd``,
        ``n_layer``, ``dropout``, ``bias`` and, optionally, ``group_size``
        (default 4) and ``use_rope`` (default true).

        Raises:
            ValueError: If ``group_size`` is not a positive integer.
        """
        super().__init__(config)
        self.config = config
        self.group_size = int(getattr(config, "group_size", 4))
        if self.group_size < 1:
            # Fail early with a readable message rather than a
            # ZeroDivisionError or an obscure layer-construction error below.
            raise ValueError(
                f"group_size must be a positive integer, got {self.group_size}"
            )
        grouped_seq_len = max(1, config.seq_len // self.group_size)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # A learned position table (over grouped positions) is only created
        # when rotary embeddings are disabled.
        self.wpe = None
        if not getattr(config, "use_rope", True):
            self.wpe = nn.Embedding(grouped_seq_len, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = RMSNorm(config.n_embd, bias=config.bias)
        self.group_in_proj = nn.Linear(self.group_size * config.n_embd, config.n_embd, bias=False)
        self.parallel_heads = nn.ModuleList(
            [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(self.group_size)]
        )

        self.apply(self._init_weights)
        # Output projections of the residual branches get a smaller std,
        # scaled down by sqrt(2 * n_layer).
        for name, param in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.wte = value

    def _group_embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Embed ``input_ids`` and merge every ``group_size`` tokens into one.

        Args:
            input_ids: ``(batch, time)`` token ids with ``time`` divisible by
                ``group_size``.

        Returns:
            ``(batch, time // group_size, n_embd)`` tensor after the group
            projection, optional learned positions, and dropout.

        Raises:
            ValueError: If ``time`` is not a multiple of ``group_size``.
        """
        bsz, tsz = input_ids.shape
        if tsz % self.group_size != 0:
            raise ValueError(
                f"Sequence length {tsz} must be divisible by group_size={self.group_size}"
            )
        n_groups = tsz // self.group_size
        tok_emb = self.wte(input_ids)
        # Lay the embeddings of each group side by side on the last axis.
        grouped = tok_emb.view(bsz, n_groups, self.group_size, self.config.n_embd)
        grouped = grouped.reshape(bsz, n_groups, self.group_size * self.config.n_embd)
        x = self.group_in_proj(grouped)
        if self.wpe is not None:
            pos = torch.arange(x.size(1), device=input_ids.device)
            x = x + self.wpe(pos)
        return self.drop(x)

    def _decode_parallel_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Expand grouped hidden states into per-token logits.

        Head ``i`` supplies the logits for the ``i``-th token of every group;
        stacking on a new axis 2 and flattening axes 1-2 interleaves the heads
        so the output follows the original token order.

        Returns:
            ``(batch, grouped_time * group_size, vocab)`` logits.
        """
        per_head = [head(x) for head in self.parallel_heads]
        logits = torch.stack(per_head, dim=2)  # (B, T_g, G, V)
        bsz, tg, gsz, vsz = logits.shape
        return logits.reshape(bsz, tg * gsz, vsz)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        seq: Optional[torch.LongTensor] = None,
        tgt: Optional[torch.LongTensor] = None,
    ):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: ``(batch, time)`` token ids. Trailing tokens that do
                not fill a whole group are dropped.
            labels: Targets aligned with the logits; truncated along axis 1
                the same way as ``input_ids`` when truncation happens.
            output_hidden_states: Also return the grouped embeddings and the
                output of every block. ``None`` defers to the config.
            return_dict: Return a ``CausalLMOutput`` instead of a tuple.
                ``None`` defers to the config.
            seq: Alias that overrides ``input_ids`` when given.
            tgt: Alias that overrides ``labels`` when given.

        Returns:
            A ``CausalLMOutput``; or, when ``return_dict`` is false, a tuple
            ``(loss?, logits, hidden_states?)`` containing only the entries
            that were produced.

        Raises:
            ValueError: If no input is supplied or the input is shorter than
                one group.
        """
        if seq is not None:
            input_ids = seq
        if tgt is not None:
            labels = tgt
        if input_ids is None:
            raise ValueError("input_ids (or seq) must be provided")

        # Both flags are Optional: an explicit None means "use the configured
        # default" rather than silently behaving like False.
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
        if usable_len <= 0:
            raise ValueError(
                f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
            )
        if usable_len != input_ids.shape[1]:
            # Drop the ragged tail so the sequence splits into whole groups.
            input_ids = input_ids[:, :usable_len]
            if labels is not None:
                labels = labels[:, :usable_len]

        x = self._group_embed(input_ids)

        collected = [x] if output_hidden_states else []
        for block in self.h:
            x = block(x)
            if output_hidden_states:
                collected.append(x)
        all_hidden_states = tuple(collected)

        x = self.ln_f(x)
        logits = self._decode_parallel_logits(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=getattr(self.config, "ignore_index", -100),
            )

        if not return_dict:
            out = (logits,)
            if output_hidden_states:
                out = out + (all_hidden_states,)
            return ((loss,) + out) if loss is not None else out

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=None,
        )