from typing import Any
import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
import torch.nn.functional as F


class QFormerAttention(nn.Module):
    """Multi-headed self-attention for QFormer with SDPA/Flash Attention support."""

    def __init__(self, hidden_size, num_heads, attn_bias=False, attention_dropout=0.0):
        """
        Args:
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            attn_bias: whether the Q/K/V/O projections carry a bias.
            attention_dropout: dropout probability applied inside SDPA during training.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout

        if self.head_dim * num_heads != hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size} "
                f"and `num_heads`: {num_heads})."
            )

        # Separate Q, K, V projections
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: (B, seq_len, hidden_size)
            attention_mask: optional mask forwarded unchanged as ``attn_mask`` to
                ``torch.nn.functional.scaled_dot_product_attention``.

        Returns:
            (B, seq_len, hidden_size)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project and reshape to (B, num_heads, seq_len, head_dim)
        query_states = self.q_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # SDPA dispatches to a fused (e.g. Flash) kernel when one is available.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,
        )

        # Reshape back to (B, seq_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output


class QFormerMLP(nn.Module):
    """Gated MLP for QFormer: ``down_proj(GELU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size, intermediate_size, mlp_bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)
        self.act_fn = nn.GELU()

    def forward(self, x):
        """
        Args:
            x: (B, seq_len, hidden_size)

        Returns:
            (B, seq_len, hidden_size)
        """
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class QFormerLayer(nn.Module):
    """Single pre-norm transformer layer: self-attention then MLP, each with a residual."""

    def __init__(self, hidden_size, num_heads, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.attention = QFormerAttention(hidden_size, num_heads)
        self.attention_norm = nn.LayerNorm(hidden_size)

        self.mlp = QFormerMLP(hidden_size, intermediate_size)
        self.mlp_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: (B, seq_len, hidden_size)
            attention_mask: optional attention mask (passed to QFormerAttention)

        Returns:
            (B, seq_len, hidden_size)
        """
        # Self-attention with residual and pre-norm
        residual = hidden_states
        hidden_states = self.attention_norm(hidden_states)
        hidden_states: Any = self.attention(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        # MLP with residual and pre-norm
        residual = hidden_states
        hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class SimplifiedQFormer(nn.Module):
    """
    Simplified QFormer with full self-attention between queries and inputs.

    Queries and encoder tokens are concatenated into one sequence and processed
    by plain self-attention layers; only the query positions are returned.
    This replaces Blip2QFormerModel with a cleaner implementation.
    """

    def __init__(self, hidden_size, num_heads=8, num_layers=1, intermediate_size=None):
        """
        Args:
            hidden_size: width of queries and encoder states.
            num_heads: attention heads per layer.
            num_layers: number of QFormerLayer blocks.
            intermediate_size: MLP width; defaults to ``4 * hidden_size``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers

        if intermediate_size is None:
            intermediate_size = hidden_size * 4

        # Create transformer layers
        self.layers = nn.ModuleList([
            QFormerLayer(hidden_size, num_heads, intermediate_size)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, query_embeds, encoder_hidden_states):
        """
        Args:
            query_embeds: (B, num_queries, hidden_size) - learnable queries
            encoder_hidden_states: (B, num_tokens, hidden_size) - input features

        Returns:
            (B, num_queries, hidden_size) - output features
        """
        # Concatenate queries and encoder states for full self-attention
        # Shape: (B, num_queries + num_tokens, hidden_size)
        hidden_states = torch.cat([query_embeds, encoder_hidden_states], dim=1)

        # Apply transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states)

        # Extract only the query outputs
        num_queries = query_embeds.shape[1]
        output = hidden_states[:, :num_queries, :]

        return self.norm(output)


class InterpolateDownsampler:
    """Downsample a square grid of patch features with ``interpolate``.

    Input is a raster-ordered sequence ``(B, side * side, C)`` with
    ``side = image_size // patch_size``; output is ``(B, new_side * new_side, C)``
    with ``new_side = int(side * Fraction(downsample_rate))``.
    """

    def __init__(self, config, mode="area"):
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))
        self.mode = mode

    def __call__(self, image_features):
        batch_size, _, dim = image_features.size()
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]
        # interpolate expects B,C,H,W. reshape (unlike view) also accepts
        # non-contiguous inputs and is a zero-copy view whenever view would work.
        large_image_permuted = image_features.reshape(up_shape).permute(0, 3, 1, 2)
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode=self.mode,
        )
        # back to B,H*W,C (raster order)
        final = small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)
        return final


def _unreduced_ratio(rate):
    """Return ``(numerator, denominator)`` of ``rate`` without reducing the fraction.

    The windowed downsampler reads two numbers from ``downsample_rate``. For
    example, ``"2/4"`` means 2x2 queries per 4x4 window, which differs from
    ``"1/2"``. ``Fraction(..., _normalize=False)`` used to preserve this, but that
    private keyword was removed from CPython's ``fractions`` in Python 3.12.

    Strings of the form ``"a/b"`` keep their literal numerator and denominator.
    Any other value (int, float, Fraction, decimal string) is converted with
    ``fractions.Fraction`` and is therefore reduced.

    Raises:
        ValueError: if a string numerator/denominator is not an integer.
        ZeroDivisionError: if the denominator is zero.
    """
    if isinstance(rate, str):
        num_text, sep, den_text = rate.strip().partition("/")
        if sep:
            numerator, denominator = int(num_text), int(den_text)
            if denominator == 0:
                raise ZeroDivisionError(f"Fraction({numerator}, 0)")
            return numerator, denominator
    return Fraction(rate).as_integer_ratio()


class QFormerDownsampler(nn.Module):
    """Compress all image patches into ``query_side ** 2`` tokens with a Blip2 QFormer.

    Learnable queries (plus the interpolated features) cross-attend to the
    position-augmented patch features. The interpolated features are added back
    as a residual.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``text_config.hidden_size``,
                ``vision_config.image_size``, ``vision_config.patch_size`` and
                ``downsample_rate``.

        Raises:
            ValueError: if ``image_side * downsample_rate`` is not an integer.
        """
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size

        self.interpolate = InterpolateDownsampler(config)

        configuration = Blip2QFormerConfig(
            hidden_size=llm_hidden_size,
            num_attention_heads=32,
            intermediate_size=4096,
            num_hidden_layers=1,
            encoder_hidden_size=llm_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        down = Fraction(config.downsample_rate)
        query_side = self.image_side * down
        # Explicit check instead of assert: asserts vanish under `python -O`,
        # which would silently keep only the numerator below.
        if query_side.denominator != 1:
            raise ValueError("downsample_rate must make query_side an integer")
        self.query_side = query_side.numerator
        # query length is square (query_side x query_side) for seamless integration with llava next
        self.query_length = self.query_side ** 2

        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)

        # qformer model doesn't have positional embeddings, adding to the flat patches
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        """
        Args:
            image_features: (B, image_side ** 2, hidden) raster-ordered patches.

        Returns:
            (B, query_side ** 2, hidden)
        """
        interpolated = self.interpolate(image_features)
        query_output = self.qformer(
            query_embeds=self.query + interpolated,
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state
        return query_output + interpolated


class WindowQFormerDownsampler(nn.Module):
    """Windowed QFormer downsampler.

    The patch grid is split into non-overlapping ``window_side x window_side``
    windows. Each window is compressed to ``query_side x query_side`` tokens by a
    QFormer, and the windows are stitched back into a raster grid. The
    interpolated features are added as a residual, and the result is projected
    to the LLM width.
    """

    def __init__(self, config, use_simplified_qformer=False):
        """
        Args:
            config: object exposing ``text_config.hidden_size``,
                ``vision_config.hidden_size``, ``vision_config.image_size``,
                ``vision_config.patch_size`` and ``downsample_rate``. A string rate
                ``"q/w"`` is read unreduced as (query_side, window_side).
            use_simplified_qformer: use SimplifiedQFormer (joint self-attention)
                instead of Blip2QFormerModel (cross-attention).

        Raises:
            ValueError: if the sides are non-positive or ``image_side`` is not a
                multiple of ``window_side``.
        """
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size

        self.interpolate = InterpolateDownsampler(config)
        self.use_simplified_qformer = use_simplified_qformer

        # Choose between SimplifiedQFormer and Blip2QFormerModel
        if use_simplified_qformer:
            # Use our simplified QFormer with full self-attention
            self.qformer = SimplifiedQFormer(
                hidden_size=vision_hidden_size,
                num_heads=18,
                num_layers=1,
                intermediate_size=4096,
            )
        else:
            # Use original Blip2QFormerModel with cross-attention
            configuration = Blip2QFormerConfig(
                hidden_size=vision_hidden_size,
                num_attention_heads=16,
                intermediate_size=4096,
                num_hidden_layers=1,
                encoder_hidden_size=vision_hidden_size,
                cross_attention_frequency=1,
                max_position_embeddings=2048,
                use_qformer_text_input=False,
            )
            self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size

        # Keep the ratio unreduced: "2/4" => 2x2 queries per 4x4 window.
        self.query_side, self.window_side = _unreduced_ratio(config.downsample_rate)
        if self.query_side <= 0 or self.window_side <= 0:
            raise ValueError(
                f"downsample_rate must be positive, got {config.downsample_rate!r}"
            )
        # Windows must tile the patch grid exactly; otherwise the window reshape in
        # forward() would fail with an opaque view/shape error.
        if self.image_side % self.window_side != 0:
            raise ValueError(
                f"image side ({self.image_side} patches) must be divisible by the "
                f"window side ({self.window_side}) from downsample_rate "
                f"{config.downsample_rate!r}"
            )

        # query length is square (query_side x query_side) for seamless integration with llava next
        self.query_length = self.query_side ** 2

        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)

        # qformer model doesn't have positional embeddings, adding to the flat patches
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)

        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=False)

    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win.

        Windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)  # (B, n, n, win, win, C)
            .flatten(0, 2)  # (B*n*n, win, win, C)
            .flatten(1, 2)  # (B*n*n, win*win, C)
        )

    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster; inverse of ``_win``.
        """
        Bnn, _, C = xw.shape
        assert Bnn % (n * n) == 0  # internal invariant: input came from _win
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)  # (B, n, win, n, win, C)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )

    def forward(self, image_features):
        """
        Args:
            image_features: (B, image_side ** 2, C) raster-ordered patches.

        Returns:
            (B, (n * query_side) ** 2, llm_hidden_size) with n = image_side // window_side.

        Raises:
            ValueError: if the token count does not match ``image_side ** 2``.
        """
        B, HW, C = image_features.shape
        if HW != self.image_side * self.image_side:
            raise ValueError(
                f"expected {self.image_side * self.image_side} image tokens, got {HW}"
            )
        n = self.image_side // self.window_side

        enc = self._win(image_features, self.image_side, self.window_side)  # (B*n^2, w^2, C)

        interpolated = self.interpolate(image_features)  # (B, new_side^2, C) raster
        new_side = n * self.query_side
        interpolated_w = self._win(interpolated, new_side, self.query_side)  # (B*n^2, q^2, C)

        # Apply QFormer based on the chosen mechanism
        if self.use_simplified_qformer:
            # SimplifiedQFormer: full self-attention between queries and inputs.
            # The (1, ...) query / position parameters broadcast over B*n^2 windows.
            out_w = self.qformer(
                query_embeds=self.query + interpolated_w,
                encoder_hidden_states=enc + self.image_positions,
            )  # (B*n^2, q^2, C)
        else:
            # Blip2QFormerModel: cross-attention mechanism
            out_w = self.qformer(
                query_embeds=self.query + interpolated_w,
                encoder_hidden_states=enc + self.image_positions,
                return_dict=True,
            ).last_hidden_state  # (B*n^2, q^2, C)

        out = self._unwin(out_w, n=n, win=self.query_side)  # (B, new_side^2, C) raster
        return self.out_linear(out + interpolated)