This is an optimized version of the text encoder used in Flux 2 Klein 9B. Same weights and architecture (Qwen3); the code is just stripped down so that, under torch.compile, it runs about 1.3x faster and uses a couple of gigabytes less peak VRAM. Load it with:

import torch

qwen_model = FluxQwen3TorchEmbedder.from_pretrained("fancyfeast/flux2klein-optimized-text-embedder-9B", torch_dtype=torch.bfloat16)
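
A fuller sketch of encoding a prompt, continuing from the load above. The tokenizer location and the example prompt are assumptions; what the code below does reflect is that forward() takes input_ids and attention_mask padded to exactly max_sequence_length (512 by default) and returns the hidden states after layers 9, 18, and 27 concatenated along the last dimension.

from transformers import AutoTokenizer

# Assumption: the matching Qwen3 tokenizer is published alongside this repo;
# otherwise load it from the original flux2klein text-encoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained("fancyfeast/flux2klein-optimized-text-embedder-9B")

qwen_model = qwen_model.to("cuda").eval()

# forward() expects sequences padded/truncated to exactly max_sequence_length (512).
tokens = tokenizer(
    ["a watercolor painting of a fox"],
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
).to("cuda")

with torch.inference_mode():
    prompt_embeds = qwen_model(tokens.input_ids, tokens.attention_mask)

# prompt_embeds: [batch, 512, 3 * hidden_size] -- layers 9, 18, 27 concatenated.
# The 1.3x figure assumes the forward pass is compiled, e.g.:
#     qwen_model.forward = torch.compile(qwen_model.forward)

The full implementation follows.
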
from __future__ import annotations

import json
import math
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers import PreTrainedModel


class FluxQwen3TorchEmbedder(PreTrainedModel):
    """Stripped down and optimized Qwen3 specifically for Flux 2 Klein models.
    In my testing this is about 1.3x faster than using the original HF implementation, and saves ~3GB of peak memory on the 8GB model.

    The output_hidden_state_indices is 9, 18, 27 for both Klein 4B and Klein 9B.
    """
    config_class = Qwen3Config
    base_model_prefix = "flux_qwen3"

    def __init__(
        self,
        config: Qwen3Config,
        *,
        output_hidden_state_indices: tuple[int, ...] = (9, 18, 27),
        max_sequence_length: int = 512,
    ):
        super().__init__(config)

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_attention_heads))
        self.rope_theta = float(getattr(config, "rope_theta", 1000000.0))

        self.output_hidden_state_indices = tuple(int(i) for i in output_hidden_state_indices)
        if not self.output_hidden_state_indices:
            raise ValueError("output_hidden_state_indices must not be empty")
        if min(self.output_hidden_state_indices) < 1:
            raise ValueError("output hidden state indices must be >= 1 for decoder layer outputs")
        
        max_layer_needed = max(self.output_hidden_state_indices)
        if max_layer_needed > int(config.num_hidden_layers):
            raise ValueError(f"requested hidden state after layer {max_layer_needed}, but config.num_hidden_layers={config.num_hidden_layers}")

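        # Map each requested decoder layer to its slot in the concatenated output tensor.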
        self.capture_slot_by_layer = {
            layer_idx: slot for slot, layer_idx in enumerate(self.output_hidden_state_indices)
        }
        self.max_sequence_length = int(max_sequence_length)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=getattr(config, "pad_token_id", None))
        self.layers = nn.ModuleList(
            FluxQwen3TorchLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )

        # Built lazily/refreshed in forward so dtype/device tracks the model.
        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)
        self.register_buffer("causal_mask", torch.empty(0, dtype=torch.bool), persistent=False)

        self.post_init()
    
    def _maybe_refresh_caches(self, *, device: torch.device, dtype: torch.dtype):
        need_refresh = (
            self.cos_cached.numel() == 0
            or self.cos_cached.shape[0] < self.max_sequence_length
            or self.cos_cached.device != device
            or self.cos_cached.dtype != dtype
        )
        if not need_refresh:
            return
        
        cos, sin = _rotary_cache(
            self.max_sequence_length,
            self.head_dim,
            self.rope_theta,
            device=device,
            dtype=dtype,
        )

        pos = torch.arange(self.max_sequence_length, device=device)
        causal = pos[None, :] <= pos[:, None]

        self.cos_cached = cos
        self.sin_cached = sin
        self.causal_mask = causal[None, None, :, :]

    @classmethod
    def _from_original_hf_checkpoint(cls, checkpoint_path: str, subfolder: str | None) -> "FluxQwen3TorchEmbedder":
        from huggingface_hub import hf_hub_download
        import safetensors.torch
        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(checkpoint_path, subfolder=subfolder)
        assert isinstance(cfg, Qwen3Config), f"expected Qwen3Config, got {type(cfg)}"
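        # Only the first 27 decoder layers are kept: the deepest captured hidden state is layer 27.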
        cfg.num_hidden_layers = 27
        if getattr(cfg, "layer_types", None) is not None:
            cfg.layer_types = cfg.layer_types[:27]
        cfg.max_window_layers = 27
        model = cls(cfg)

        # Load the original checkpoint
        index_path = hf_hub_download(checkpoint_path, filename="model.safetensors.index.json", subfolder=subfolder)
        index = json.loads(Path(index_path).read_text())
        shard_names = set(index['weight_map'].values())
        original_checkpoint = {}

        for shard_name in shard_names:
            path = hf_hub_download(checkpoint_path, filename=shard_name, subfolder=subfolder)
            shard = safetensors.torch.load_file(path)
            original_checkpoint.update(shard)

        # Copy weights from the original checkpoint into our model
        with torch.no_grad():
            model.embed_tokens.weight.copy_(original_checkpoint["model.embed_tokens.weight"])

            for layer_idx in range(len(model.layers)):
                layer = model.layers[layer_idx]
                layer_base = f"model.layers.{layer_idx}."

                layer.input_layernorm_weight.copy_(original_checkpoint[layer_base + "input_layernorm.weight"])
                layer.post_attention_layernorm_weight.copy_(original_checkpoint[layer_base + "post_attention_layernorm.weight"])
                q = original_checkpoint[layer_base + "self_attn.q_proj.weight"]
                k = original_checkpoint[layer_base + "self_attn.k_proj.weight"]
                v = original_checkpoint[layer_base + "self_attn.v_proj.weight"]
                layer.qkv_proj_weight.copy_(torch.cat((q, k, v), dim=0))
                layer.o_proj_weight.copy_(original_checkpoint[layer_base + "self_attn.o_proj.weight"])
                layer.q_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.q_norm.weight"])
                layer.k_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.k_norm.weight"])
                gate = original_checkpoint[layer_base + "mlp.gate_proj.weight"]
                up = original_checkpoint[layer_base + "mlp.up_proj.weight"]
                layer.gate_up_proj_weight.copy_(torch.cat((gate, up), dim=0))
                layer.down_proj_weight.copy_(original_checkpoint[layer_base + "mlp.down_proj.weight"])
        
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        if input_ids.ndim != 2:
            raise ValueError(f"expected input_ids [batch, seq], got {tuple(input_ids.shape)}")
        
        batch, seq_len = input_ids.shape

        if seq_len != self.max_sequence_length:
            raise ValueError(f"sequence length {seq_len} does not match cached max {self.max_sequence_length}")

        dtype = self.embed_tokens.weight.dtype
        device = input_ids.device
        self._maybe_refresh_caches(device=device, dtype=dtype)

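        # SDPA mask: causal mask AND key-padding mask, broadcast to [batch, 1, seq, seq].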
        key_mask = attention_mask.reshape(batch, 1, 1, seq_len).to(dtype=torch.bool)
        sdpa_mask = self.causal_mask[:, :, :seq_len, :seq_len] & key_mask

        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        hidden_states = self.embed_tokens(input_ids)

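        # Preallocate the concatenated output; each captured layer fills one hidden_size slot.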
        prompt_embeds = torch.empty(
            batch,
            seq_len,
            len(self.output_hidden_state_indices) * self.hidden_size,
            device=input_ids.device,
            dtype=dtype,
        )

        for layer_number, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, cos, sin, sdpa_mask)
            slot = self.capture_slot_by_layer.get(layer_number)
            if slot is None:
                continue

            start = slot * self.hidden_size
            prompt_embeds[:, :, start : start + self.hidden_size].copy_(hidden_states)

        return prompt_embeds


class FluxQwen3TorchLayer(nn.Module):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()

        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_attention_heads = config.num_attention_heads
        assert config.num_key_value_heads is not None, "num_key_value_heads must be specified in config for FluxQwen3TorchLayer"
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_attention_heads))
        self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6))
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_width = self.num_attention_heads * self.head_dim
        self.kv_width = self.num_key_value_heads * self.head_dim
        self.k_offset = self.q_width
        self.v_offset = self.q_width + self.kv_width

        self.input_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))
        self.post_attention_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))

        self.qkv_proj_weight = nn.Parameter(torch.empty(self.q_width + 2 * self.kv_width, self.hidden_size))
        self.o_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.q_width))

        self.q_norm_weight = nn.Parameter(torch.empty(self.head_dim))
        self.k_norm_weight = nn.Parameter(torch.empty(self.head_dim))

        self.gate_up_proj_weight = nn.Parameter(torch.empty(self.intermediate_size * 2, self.hidden_size))
        self.down_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.intermediate_size))

        assert self.q_width == self.o_proj_weight.shape[1]
        assert self.o_proj_weight.shape == (self.hidden_size, self.q_width)
        assert self.qkv_proj_weight.shape == (self.q_width + 2 * self.kv_width, self.hidden_size)

    def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x_float = x.float()
        variance = x_float.pow(2).mean(dim=-1, keepdim=True)
        return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight

    def _head_rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
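        # Per-head RMSNorm (Qwen3 QK-norm); weight has shape [head_dim].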
        dtype = x.dtype
        x_float = x.float()
        variance = x_float.pow(2).mean(dim=-1, keepdim=True)
        return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight

    @staticmethod
    def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
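        # Rotate-half RoPE: cos/sin are [seq, head_dim]; x is [batch, seq, heads, head_dim].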
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos[:, None, :] + rotated * sin[:, None, :]

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        x = self._rms_norm(hidden_states, self.input_layernorm_weight)

        batch, seq_len, _ = x.shape
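        # Fused QKV projection, then split into query/key/value head views.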
        qkv = F.linear(x, self.qkv_proj_weight)
        q_raw = qkv[:, :, : self.q_width].view(batch, seq_len, self.num_attention_heads, self.head_dim)
        k_raw = qkv[:, :, self.k_offset : self.v_offset].view(
            batch, seq_len, self.num_key_value_heads, self.head_dim
        )
        v = qkv[:, :, self.v_offset :].view(batch, seq_len, self.num_key_value_heads, self.head_dim)

        q = self._apply_rope(self._head_rms_norm(q_raw, self.q_norm_weight), cos, sin).transpose(1, 2)
        k = self._apply_rope(self._head_rms_norm(k_raw, self.k_norm_weight), cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)

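        # enable_gqa lets SDPA broadcast the KV heads across query heads (needs PyTorch 2.5+).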
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            scale=self.scale,
            is_causal=False,
            enable_gqa=True,
        )
        attn = attn.transpose(1, 2).contiguous().view(batch, seq_len, self.q_width)
        hidden_states = residual + F.linear(attn, self.o_proj_weight)

        residual = hidden_states
        x = self._rms_norm(hidden_states, self.post_attention_layernorm_weight)

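        # Fused SwiGLU MLP: gate and up projections computed in a single matmul.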
        gate_up = F.linear(x, self.gate_up_proj_weight)
        gate, up = gate_up.split(self.intermediate_size, dim=-1)
        x = F.silu(gate) * up

        hidden_states = residual + F.linear(x, self.down_proj_weight)
        return hidden_states


def _rotary_cache(
    seq_len: int,
    head_dim: int,
    rope_theta: float,
    *,
    device: torch.device | str,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
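    # Rotate-half RoPE tables: cos and sin of shape [seq_len, head_dim], computed in float32 then cast.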
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    pos = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(pos, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype=dtype).contiguous(), emb.sin().to(dtype=dtype).contiguous()