YAML Metadata Warning:empty or missing yaml metadata in repo card
Check out the documentation for more information.
This is an optimized version of the text encoder used in Flux 2 Klein 9B. It uses the same weights and architecture (Qwen3), but with stripped-down code that, under torch.compile, is about 1.3x faster and uses less peak VRAM (it should save a couple of gigabytes).
qwen_model = FluxQwen3TorchEmbedder.from_pretrained("fancyfeast/flux2klein-optimized-text-embedder-9B", torch_dtype=torch.bfloat16)
from __future__ import annotations
import json
import math
from pathlib import Path
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers import PreTrainedModel
class FluxQwen3TorchEmbedder(PreTrainedModel):
    """Stripped down and optimized Qwen3 specifically for Flux 2 Klein models.

    In my testing this is about 1.3x faster than using the original HF implementation, and saves ~3GB of peak memory on the 8GB model.
    The output_hidden_state_indices is 9, 18, 27 for both Klein 4B and Klein 9B.

    ``forward`` runs the token embeddings through the decoder layers and
    returns the hidden states captured after the requested layers,
    concatenated along the last dimension in the order the indices were given.
    """

    config_class = Qwen3Config
    base_model_prefix = "flux_qwen3"

    def __init__(
        self,
        config: Qwen3Config,
        *,
        output_hidden_state_indices: tuple[int, ...] = (9, 18, 27),
        max_sequence_length: int = 512,
    ):
        """Build the embedding table, decoder layers and (empty) rotary/mask caches.

        Args:
            config: Qwen3 configuration describing the model dimensions.
            output_hidden_state_indices: 1-based decoder layer numbers whose
                outputs are captured. Must be non-empty, unique, >= 1 and
                <= ``config.num_hidden_layers``.
            max_sequence_length: The exact sequence length ``forward`` accepts;
                the rotary tables and causal mask are built for this length.

        Raises:
            ValueError: If ``output_hidden_state_indices`` is empty, contains
                duplicates, or references a layer outside ``1..num_hidden_layers``.
        """
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_attention_heads))
        self.rope_theta = float(getattr(config, "rope_theta", 1000000.0))

        self.output_hidden_state_indices = tuple(int(i) for i in output_hidden_state_indices)
        if not self.output_hidden_state_indices:
            raise ValueError("output_hidden_state_indices must not be empty")
        # The layer -> slot mapping below is a dict, so a repeated layer index
        # would keep only its last slot and leave the earlier slot of the
        # torch.empty() output buffer unwritten (uninitialized memory).
        if len(set(self.output_hidden_state_indices)) != len(self.output_hidden_state_indices):
            raise ValueError(
                f"output_hidden_state_indices must be unique, got {self.output_hidden_state_indices}"
            )
        if min(self.output_hidden_state_indices) < 1:
            raise ValueError("output hidden state indices must be >= 1 for decoder layer outputs")
        max_layer_needed = max(self.output_hidden_state_indices)
        if max_layer_needed > int(config.num_hidden_layers):
            raise ValueError(f"requested hidden state after layer {max_layer_needed}, but config.num_hidden_layers={config.num_hidden_layers}")

        # Maps a 1-based layer number to the position of its hidden state in
        # the concatenated output.
        self.capture_slot_by_layer = {
            layer_idx: slot for slot, layer_idx in enumerate(self.output_hidden_state_indices)
        }
        self.max_sequence_length = int(max_sequence_length)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=getattr(config, "pad_token_id", None))
        self.layers = nn.ModuleList(
            FluxQwen3TorchLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )

        # Built lazily/refreshed in forward so dtype/device tracks the model.
        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)
        self.register_buffer("causal_mask", torch.empty(0, dtype=torch.bool), persistent=False)

        self.post_init()

    def _maybe_refresh_caches(self, *, device: torch.device, dtype: torch.dtype):
        """Rebuild the rotary tables and causal mask if they are missing or stale.

        The caches are considered stale when they are empty, shorter than
        ``max_sequence_length``, or on a different device/dtype than requested.
        """
        need_refresh = (
            self.cos_cached.numel() == 0
            or self.cos_cached.shape[0] < self.max_sequence_length
            or self.cos_cached.device != device
            or self.cos_cached.dtype != dtype
        )
        if not need_refresh:
            return

        cos, sin = _rotary_cache(
            self.max_sequence_length,
            self.head_dim,
            self.rope_theta,
            device=device,
            dtype=dtype,
        )
        pos = torch.arange(self.max_sequence_length, device=device)
        # True where the key position is at or before the query position.
        causal = pos[None, :] <= pos[:, None]

        self.cos_cached = cos
        self.sin_cached = sin
        self.causal_mask = causal[None, None, :, :]

    @classmethod
    def _from_original_hf_checkpoint(cls, checkpoint_path: str, subfolder: str | None) -> "FluxQwen3TorchEmbedder":
        """Build the model from an original (unfused) sharded Qwen3 checkpoint.

        The config is truncated to the first 27 decoder layers, the safetensors
        shards listed in ``model.safetensors.index.json`` are downloaded, and
        the per-layer q/k/v and gate/up weights are concatenated into the fused
        parameters used by ``FluxQwen3TorchLayer``.

        Args:
            checkpoint_path: Repository id passed to ``AutoConfig.from_pretrained``
                and ``hf_hub_download``.
            subfolder: Optional subfolder inside the repository.

        Raises:
            TypeError: If the loaded config is not a ``Qwen3Config``.
        """
        from huggingface_hub import hf_hub_download
        import safetensors.torch
        from transformers import AutoConfig

        # Only the first 27 layers are kept: 27 is the deepest hidden state
        # captured with the default output_hidden_state_indices.
        num_layers_kept = 27

        cfg = AutoConfig.from_pretrained(checkpoint_path, subfolder=subfolder)
        if not isinstance(cfg, Qwen3Config):
            raise TypeError(f"expected Qwen3Config, got {type(cfg)}")
        cfg.num_hidden_layers = num_layers_kept
        # Read defensively: a config without a layer_types attribute would
        # otherwise fail here with AttributeError.
        layer_types = getattr(cfg, "layer_types", None)
        if layer_types is not None:
            cfg.layer_types = layer_types[:num_layers_kept]
        cfg.max_window_layers = num_layers_kept

        model = cls(cfg)

        # Load the original checkpoint
        index_path = hf_hub_download(checkpoint_path, filename="model.safetensors.index.json", subfolder=subfolder)
        index = json.loads(Path(index_path).read_text())
        shard_names = sorted(set(index['weight_map'].values()))

        original_checkpoint = {}
        for shard_name in shard_names:
            path = hf_hub_download(checkpoint_path, filename=shard_name, subfolder=subfolder)
            shard = safetensors.torch.load_file(path)
            original_checkpoint.update(shard)

        # Copy weights from the original checkpoint into our model
        with torch.no_grad():
            model.embed_tokens.weight.copy_(original_checkpoint["model.embed_tokens.weight"])

            for layer_idx, layer in enumerate(model.layers):
                layer_base = f"model.layers.{layer_idx}."

                layer.input_layernorm_weight.copy_(original_checkpoint[layer_base + "input_layernorm.weight"])
                layer.post_attention_layernorm_weight.copy_(original_checkpoint[layer_base + "post_attention_layernorm.weight"])

                # Fused QKV: rows are laid out as [q | k | v].
                q = original_checkpoint[layer_base + "self_attn.q_proj.weight"]
                k = original_checkpoint[layer_base + "self_attn.k_proj.weight"]
                v = original_checkpoint[layer_base + "self_attn.v_proj.weight"]
                layer.qkv_proj_weight.copy_(torch.cat((q, k, v), dim=0))
                layer.o_proj_weight.copy_(original_checkpoint[layer_base + "self_attn.o_proj.weight"])
                layer.q_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.q_norm.weight"])
                layer.k_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.k_norm.weight"])

                # Fused MLP input projection: rows are laid out as [gate | up].
                gate = original_checkpoint[layer_base + "mlp.gate_proj.weight"]
                up = original_checkpoint[layer_base + "mlp.up_proj.weight"]
                layer.gate_up_proj_weight.copy_(torch.cat((gate, up), dim=0))
                layer.down_proj_weight.copy_(original_checkpoint[layer_base + "mlp.down_proj.weight"])

        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode token ids into the concatenated captured hidden states.

        Args:
            input_ids: Token ids of shape ``[batch, max_sequence_length]``.
            attention_mask: Per-token mask that is reshaped to
                ``[batch, 1, 1, seq]`` and cast to bool; True/non-zero marks
                keys that may be attended to.

        Returns:
            Tensor of shape
            ``[batch, seq, len(output_hidden_state_indices) * hidden_size]``
            in the dtype of the embedding weights.

        Raises:
            ValueError: If ``input_ids`` is not 2-D or its sequence length
                differs from ``max_sequence_length``.
        """
        if input_ids.ndim != 2:
            raise ValueError(f"expected input_ids [batch, seq], got {tuple(input_ids.shape)}")
        batch, seq_len = input_ids.shape
        if seq_len != self.max_sequence_length:
            raise ValueError(f"sequence length {seq_len} does not match cached max {self.max_sequence_length}")

        dtype = self.embed_tokens.weight.dtype
        device = input_ids.device
        self._maybe_refresh_caches(device=device, dtype=dtype)

        # Combine the causal constraint with the key padding mask.
        key_mask = attention_mask.reshape(batch, 1, 1, seq_len).to(dtype=torch.bool)
        sdpa_mask = self.causal_mask[:, :, :seq_len, :seq_len] & key_mask
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        hidden_states = self.embed_tokens(input_ids)
        # Every slot is overwritten below because each requested layer index is
        # unique and within range (validated in __init__).
        prompt_embeds = torch.empty(
            batch,
            seq_len,
            len(self.output_hidden_state_indices) * self.hidden_size,
            device=input_ids.device,
            dtype=dtype,
        )

        for layer_number, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, cos, sin, sdpa_mask)
            slot = self.capture_slot_by_layer.get(layer_number)
            if slot is None:
                continue
            start = slot * self.hidden_size
            prompt_embeds[:, :, start : start + self.hidden_size].copy_(hidden_states)

        return prompt_embeds
class FluxQwen3TorchLayer(nn.Module):
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.num_attention_heads = config.num_attention_heads
assert config.num_key_value_heads is not None, "num_key_value_heads must be specified in config for FluxQwen3TorchLayer"
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_attention_heads))
self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6))
self.scale = 1.0 / math.sqrt(self.head_dim)
self.q_width = self.num_attention_heads * self.head_dim
self.kv_width = self.num_key_value_heads * self.head_dim
self.k_offset = self.q_width
self.v_offset = self.q_width + self.kv_width
self.input_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))
self.post_attention_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))
self.qkv_proj_weight = nn.Parameter(torch.empty(self.q_width + 2 * self.kv_width, self.hidden_size))
self.o_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.q_width))
self.q_norm_weight = nn.Parameter(torch.empty(self.head_dim))
self.k_norm_weight = nn.Parameter(torch.empty(self.head_dim))
self.gate_up_proj_weight = nn.Parameter(torch.empty(self.intermediate_size * 2, self.hidden_size))
self.down_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.intermediate_size))
assert self.q_width == self.o_proj_weight.shape[1]
assert self.o_proj_weight.shape == (self.hidden_size, self.q_width)
assert self.qkv_proj_weight.shape == (self.q_width + 2 * self.kv_width, self.hidden_size)
def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x_float = x.float()
variance = x_float.pow(2).mean(dim=-1, keepdim=True)
return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight
def _head_rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x_float = x.float()
variance = x_float.pow(2).mean(dim=-1, keepdim=True)
return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight
@staticmethod
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
half = x.shape[-1] // 2
x1 = x[..., :half]
x2 = x[..., half:]
rotated = torch.cat((-x2, x1), dim=-1)
return x * cos[:, None, :] + rotated * sin[:, None, :]
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
x = self._rms_norm(hidden_states, self.input_layernorm_weight)
batch, seq_len, _ = x.shape
qkv = F.linear(x, self.qkv_proj_weight)
q_raw = qkv[:, :, : self.q_width].view(batch, seq_len, self.num_attention_heads, self.head_dim)
k_raw = qkv[:, :, self.k_offset : self.v_offset].view(
batch, seq_len, self.num_key_value_heads, self.head_dim
)
v = qkv[:, :, self.v_offset :].view(batch, seq_len, self.num_key_value_heads, self.head_dim)
q = self._apply_rope(self._head_rms_norm(q_raw, self.q_norm_weight), cos, sin).transpose(1, 2)
k = self._apply_rope(self._head_rms_norm(k_raw, self.k_norm_weight), cos, sin).transpose(1, 2)
v = v.transpose(1, 2)
attn = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_mask,
dropout_p=0.0,
scale=self.scale,
is_causal=False,
enable_gqa=True,
)
attn = attn.transpose(1, 2).contiguous().view(batch, seq_len, self.q_width)
hidden_states = residual + F.linear(attn, self.o_proj_weight)
residual = hidden_states
x = self._rms_norm(hidden_states, self.post_attention_layernorm_weight)
gate_up = F.linear(x, self.gate_up_proj_weight)
gate, up = gate_up.split(self.intermediate_size, dim=-1)
x = F.silu(gate) * up
hidden_states = residual + F.linear(x, self.down_proj_weight)
return hidden_states
def _rotary_cache(
seq_len: int,
head_dim: int,
rope_theta: float,
*,
device: torch.device | str,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
)
pos = torch.arange(seq_len, dtype=torch.float32, device=device)
freqs = torch.outer(pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=dtype).contiguous(), emb.sin().to(dtype=dtype).contiguous()
- Downloads last month
- 25
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support