"""BinaryLLM causal LM: a small transformer with a portable attention layer.

The attention layer (`FlashSelfAttentionPortable`) can dispatch to an optional
V100-specific flash-attention CUDA kernel, to PyTorch SDPA, or to a pure-eager
fallback; the rest of the file wires it into a HuggingFace `PreTrainedModel`.
"""

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_binaryllm import BinaryLLMConfig

# Optional V100 (sm70) flash-attention kernel; absence is not an error — we
# simply fall back to SDPA / eager attention.
try:
    import flash_attn_v100_cuda

    _FLASH_V100_AVAILABLE = True
except Exception:
    flash_attn_v100_cuda = None
    _FLASH_V100_AVAILABLE = False


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    The table is stored in fp32 and cast to the dtype/device of the input on
    every forward, so the module works unchanged under mixed precision.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) for broadcast over batch
        # Non-persistent: the table is deterministic, no need to checkpoint it.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to `x` of shape (batch, seq, d_model)."""
        t = x.size(1)
        pe = self.pe[:, :t, :].to(device=x.device, dtype=x.dtype)
        return x + pe


@dataclass
class _InnerCfg:
    """Internal hyperparameter bundle decoupled from the HF config object."""

    block_size: int          # maximum sequence length
    embed_dim: int           # transformer hidden size
    vocab_size: int
    num_heads: int
    num_layers: int
    ff_hidden_dim: int       # feed-forward inner dimension
    dropout: float
    layernorm_dim: Optional[int] = None  # final LayerNorm width (defaults to embed_dim)
    head_dim: Optional[int] = None       # LM-head input width (NOT per-head attn dim)
    attn_backend: str = "auto"           # "auto" | "v100" | "sdpa" | "eager"


class FlashSelfAttentionPortable(nn.Module):
    """Multi-head causal self-attention with selectable backends.

    Backends:
      * "v100"  — custom flash-attention CUDA kernel for sm70 GPUs (fp16 only,
                  no padding mask support);
      * "sdpa"  — `F.scaled_dot_product_attention`;
      * "eager" — explicit fp32 matmul/softmax reference implementation;
      * "auto"  — V100 kernel when usable, otherwise SDPA.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        causal: bool = True,
        backend: str = "auto",
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) doit être divisible par num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = float(dropout)
        self.causal = bool(causal)
        self.backend = str(backend)
        self.softmax_scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _shape_qkv(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.dtype]:
        """Project `x` to q/k/v shaped (batch, heads, seq, head_dim).

        Also returns the input dtype so the output can be cast back for the
        residual connection (projections may live in a different dtype under
        mixed precision).
        """
        bsz, seqlen, _ = x.shape
        residual_dtype = x.dtype
        proj_dtype = self.q_proj.weight.dtype
        if x.dtype != proj_dtype:
            x = x.to(proj_dtype)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = q.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        return q, k, v, residual_dtype

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, heads, seq, head_dim) -> (batch, seq, heads*head_dim)."""
        bsz, nheads, seqlen, head_dim = x.shape
        return x.transpose(1, 2).contiguous().view(bsz, seqlen, nheads * head_dim)

    def _can_use_v100_kernel(self, q: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> bool:
        """True iff the V100 kernel can handle this call.

        Requires: the extension is importable, q is on a CUDA sm70 device,
        there is no active padding, and head_dim is a multiple of 8 and <= 256.
        """
        if not _FLASH_V100_AVAILABLE:
            return False
        if not q.is_cuda:
            return False
        # The kernel has no key-padding support; any padded position disqualifies it.
        if padding_mask is not None and bool(padding_mask.any().item()):
            return False
        cc = torch.cuda.get_device_capability(q.device)
        if cc != (7, 0):
            return False
        hd = q.size(-1)
        if hd % 2 != 0:
            return False
        if hd % 8 != 0:
            return False
        if hd > 256:
            return False
        return True

    def _flash_attn_v100(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Run the sm70 flash-attention kernel (fp16 only; casts if needed)."""
        if q.dtype != torch.float16:
            q = q.to(torch.float16)
        if k.dtype != torch.float16:
            k = k.to(torch.float16)
        if v.dtype != torch.float16:
            v = v.to(torch.float16)
        # Positional signature of the extension's fwd(); kept verbatim.
        result = flash_attn_v100_cuda.fwd(
            q,
            k,
            v,
            None,
            None,
            0.0,
            self.softmax_scale,
            self.causal,
            -1,
            -1,
            0.0,
            False,
            None,
        )
        out = result[0]
        return out

    def _sdpa_attn(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attention via `F.scaled_dot_product_attention`.

        `padding_mask` is a bool tensor (batch, seq_k), True at PAD positions.
        """
        bsz, nheads, tq, _ = q.shape
        tk = k.size(-2)
        attn_mask = None
        if padding_mask is not None:
            key_mask = padding_mask[:, None, None, :].to(device=q.device, dtype=torch.bool)
            key_mask = key_mask.expand(bsz, nheads, tq, tk)
            # Bool attn_mask convention for SDPA: True = position may be attended.
            attn_mask = ~key_mask
            # BUGFIX: when an explicit attn_mask is passed, `is_causal` is
            # forced to False below, so the causal constraint must be folded
            # into the mask — the original dropped it for padded batches.
            if self.causal:
                causal_keep = torch.tril(
                    torch.ones(tq, tk, device=q.device, dtype=torch.bool)
                )
                attn_mask = attn_mask & causal_keep
        dropout_p = self.dropout if self.training else 0.0
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_mem_efficient=True,
            enable_math=True,
        ):
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                # is_causal and an explicit attn_mask are mutually exclusive;
                # causality is already baked into attn_mask when it is set.
                is_causal=self.causal if attn_mask is None else False,
                scale=self.softmax_scale,
            )
        return out

    def _eager_attn(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Reference attention: fp32 matmul/softmax, applies both masks."""
        scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.softmax_scale
        if self.causal:
            tq = q.size(-2)
            tk = k.size(-2)
            causal_mask = torch.triu(
                torch.ones(tq, tk, device=scores.device, dtype=torch.bool),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
        if padding_mask is not None:
            # padding_mask: True at PAD keys -> mask those columns out.
            key_mask = padding_mask[:, None, None, :].to(device=scores.device, dtype=torch.bool)
            scores = scores.masked_fill(key_mask, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        if self.training and self.dropout > 0.0:
            probs = F.dropout(probs, p=self.dropout)
        out = torch.matmul(probs, v.float())
        return out.to(q.dtype)

    def forward(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention over `x` (batch, seq, embed_dim).

        `padding_mask` (batch, seq), True at PAD positions. Raises
        RuntimeError if backend="v100" is requested but unusable, ValueError
        for an unknown backend.
        """
        q, k, v, residual_dtype = self._shape_qkv(x)
        if padding_mask is not None:
            padding_mask = padding_mask.to(device=x.device, dtype=torch.bool)
        backend = self.backend
        if backend == "v100":
            if not self._can_use_v100_kernel(q, padding_mask):
                raise RuntimeError(
                    "backend='v100' demandé mais indisponible "
                    "(flash_attn_v100_cuda absent, GPU non sm70/V100, padding présent, "
                    "ou head_dim incompatible)."
                )
            out = self._flash_attn_v100(q, k, v)
        elif backend == "sdpa":
            out = self._sdpa_attn(q, k, v, padding_mask=padding_mask)
        elif backend == "eager":
            out = self._eager_attn(q, k, v, padding_mask=padding_mask)
        elif backend == "auto":
            if self._can_use_v100_kernel(q, padding_mask):
                out = self._flash_attn_v100(q, k, v)
            else:
                out = self._sdpa_attn(q, k, v, padding_mask=padding_mask)
        else:
            raise ValueError(f"backend d'attention non supporté: {backend}")
        out = self._merge_heads(out)
        out_proj_dtype = self.out_proj.weight.dtype
        if out.dtype != out_proj_dtype:
            out = out.to(out_proj_dtype)
        out = self.out_proj(out)
        # Cast back so the caller's residual addition keeps its dtype.
        if out.dtype != residual_dtype:
            out = out.to(residual_dtype)
        return out


class FlashTransformerEncoderLayerPortable(nn.Module):
    """Post-norm transformer encoder layer built on the portable attention.

    Mirrors `nn.TransformerEncoderLayer` (batch_first, post-LN) but always
    applies causal self-attention internally; `src_mask` is accepted for API
    compatibility and ignored.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float = 0.1,
        activation: str = "gelu",
        batch_first: bool = True,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()
        if not batch_first:
            raise ValueError("Cette implémentation supporte batch_first=True uniquement.")
        self.self_attn = FlashSelfAttentionPortable(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            causal=True,
            backend=attn_backend,
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        if activation == "gelu":
            self.activation = F.gelu
        elif activation == "relu":
            self.activation = F.relu
        else:
            raise ValueError(f"activation non supportée: {activation}")

    def _sa_block(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Self-attention sub-block (attention + dropout)."""
        x = self.self_attn(x, padding_mask=src_key_padding_mask)
        x = self.dropout1(x)
        return x

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block; casts in/out of the FF weights' dtype."""
        ff_dtype = self.linear1.weight.dtype
        x_ff = x if x.dtype == ff_dtype else x.to(ff_dtype)
        x_ff = self.linear1(x_ff)
        x_ff = self.activation(x_ff)
        x_ff = self.dropout(x_ff)
        x_ff = self.linear2(x_ff)
        x_ff = self.dropout2(x_ff)
        if x_ff.dtype != x.dtype:
            x_ff = x_ff.to(x.dtype)
        return x_ff

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,  # ignored; causality is internal
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = src
        # Post-LN residual structure: norm(x + sublayer(x)).
        x = self.norm1(x + self._sa_block(x, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x


class FlashTransformerEncoderPortable(nn.Module):
    """Stack of `FlashTransformerEncoderLayerPortable`.

    Note: layers are re-instantiated from the prototype's hyperparameters
    (d_model/nhead/ffn/dropout) rather than deep-copied, so each layer gets a
    fresh random init; activation is fixed to "gelu".
    """

    def __init__(
        self,
        encoder_layer: FlashTransformerEncoderLayerPortable,
        num_layers: int,
        attn_backend: str = "auto",
    ) -> None:
        super().__init__()
        # Recover hyperparameters from the prototype layer.
        d_model = encoder_layer.norm1.normalized_shape[0]
        nhead = encoder_layer.self_attn.num_heads
        dim_feedforward = encoder_layer.linear1.out_features
        dropout = encoder_layer.dropout.p
        self.layers = nn.ModuleList(
            [
                FlashTransformerEncoderLayerPortable(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    activation="gelu",
                    batch_first=True,
                    attn_backend=attn_backend,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = src
        for layer in self.layers:
            x = layer(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return x


class TinyTransformerLM(nn.Module):
    """Decoder-only LM: embeddings + sinusoidal PE + causal encoder stack + head.

    Optional bottlenecks: `layernorm_dim` inserts a projection before the final
    LayerNorm, `head_dim` another before the vocabulary head. When neither is
    used, the head weight is tied to the token embedding.
    """

    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg
        vocab_size = cfg.vocab_size
        self.tok_embed = nn.Embedding(vocab_size, cfg.embed_dim)
        self.pos_encoding = PositionalEncoding(cfg.embed_dim, cfg.block_size)
        encoder_layer = FlashTransformerEncoderLayerPortable(
            d_model=cfg.embed_dim,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            attn_backend=cfg.attn_backend,
        )
        self.encoder = FlashTransformerEncoderPortable(
            encoder_layer,
            num_layers=cfg.num_layers,
            attn_backend=cfg.attn_backend,
        )
        ln_dim = cfg.layernorm_dim or cfg.embed_dim
        head_dim = cfg.head_dim or ln_dim
        self.pre_ln_proj: Optional[nn.Linear] = None
        if ln_dim != cfg.embed_dim:
            self.pre_ln_proj = nn.Linear(cfg.embed_dim, ln_dim)
        self.ln = nn.LayerNorm(ln_dim)
        self.head_pre: Optional[nn.Linear] = None
        if head_dim != ln_dim:
            self.head_pre = nn.Linear(ln_dim, head_dim)
        self.head = nn.Linear(head_dim, vocab_size, bias=False)
        # Weight tying only when no intermediate projections change the width.
        if self.pre_ln_proj is None and self.head_pre is None and head_dim == cfg.embed_dim:
            self.head.weight = self.tok_embed.weight
        # Kept for API compatibility; layers apply causality internally and
        # ignore the src_mask they receive.
        causal = torch.triu(torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return logits (batch, seq, vocab) for `tokens` (batch, seq) int64.

        `padding_mask`: bool-compatible (batch, >=seq), True at PAD positions.
        """
        x = self.tok_embed(tokens)
        x = self.pos_encoding(x)
        seq_len = tokens.size(1)
        attn_mask = self.causal_mask[:seq_len, :seq_len].to(device=tokens.device)
        if padding_mask is not None:
            padding_mask = padding_mask[:, :seq_len].to(device=tokens.device, dtype=torch.bool)
        x = self.encoder(x, mask=attn_mask, src_key_padding_mask=padding_mask)
        # Optional bottleneck before the final LayerNorm; each stage casts the
        # activations into its own weights' dtype for mixed-precision safety.
        if self.pre_ln_proj is not None:
            proj_dtype = self.pre_ln_proj.weight.dtype
            if x.dtype != proj_dtype:
                x = x.to(proj_dtype)
            x = self.pre_ln_proj(x)
        ln_dtype = self.ln.weight.dtype
        if x.dtype != ln_dtype:
            x = x.to(ln_dtype)
        x = self.ln(x)
        if self.head_pre is not None:
            head_pre_dtype = self.head_pre.weight.dtype
            if x.dtype != head_pre_dtype:
                x = x.to(head_pre_dtype)
            x = self.head_pre(x)
        head_dtype = self.head.weight.dtype
        if x.dtype != head_dtype:
            x = x.to(head_dtype)
        return self.head(x)


class BinaryLLMForCausalLM(PreTrainedModel):
    """HuggingFace wrapper around `TinyTransformerLM` for causal LM training."""

    config_class = BinaryLLMConfig
    main_input_name = "input_ids"

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)
        attn_backend = getattr(config, "attn_backend", "auto")
        inner = _InnerCfg(
            block_size=int(config.max_position_embeddings),
            embed_dim=int(config.hidden_size),
            vocab_size=int(config.vocab_size),
            num_heads=int(config.num_attention_heads),
            num_layers=int(config.num_hidden_layers),
            ff_hidden_dim=int(config.intermediate_size),
            dropout=float(getattr(config, "dropout", 0.0)),
            layernorm_dim=None,
            head_dim=None,
            attn_backend=str(attn_backend),
        )
        self.model = TinyTransformerLM(inner)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.tok_embed

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.model.tok_embed = value

    def get_output_embeddings(self) -> nn.Module:
        return self.model.head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.model.head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # No KV cache: the full sequence is re-fed at each generation step.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Compute logits and, when `labels` is given, the shifted CE loss.

        `attention_mask` follows the HF convention (1 = real token); it is
        inverted into the model's padding-mask convention (True = PAD).
        Label positions equal to -100 are ignored in the loss.
        """
        padding_mask = None
        if attention_mask is not None:
            padding_mask = ~attention_mask.to(torch.bool)
        logits = self.model(input_ids, padding_mask=padding_mask)
        loss = None
        if labels is not None:
            # Standard next-token shift: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)