|
|
from typing import Any |
|
|
import torch |
|
|
from torch import nn |
|
|
import math |
|
|
from fractions import Fraction |
|
|
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig |
|
|
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class QFormerAttention(nn.Module):
    """Bidirectional multi-head self-attention for the QFormer.

    The attention itself is delegated to ``F.scaled_dot_product_attention``
    (SDPA), which PyTorch may route to a fused kernel such as FlashAttention
    when the inputs permit. No causal mask is applied.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_bias: whether the q/k/v/o projections carry a bias term.
        attention_dropout: dropout probability used by SDPA in training mode.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, attn_bias=False, attention_dropout=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout

        if hidden_size % num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size} "
                f"and `num_heads`: {num_heads})."
            )

        # Registration order kept as q, k, v, o so parameter init order is stable.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)

    def _to_heads(self, projected, batch_size, seq_len):
        """Reshape ``(B, L, hidden)`` into ``(B, num_heads, L, head_dim)``."""
        split = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return split.transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Attend every position to every other position.

        Args:
            hidden_states: ``(B, seq_len, hidden_size)`` input.
            attention_mask: optional mask forwarded unchanged to SDPA's ``attn_mask``.

        Returns:
            ``(B, seq_len, hidden_size)`` attention output after ``o_proj``.
        """
        batch_size, seq_len, _ = hidden_states.shape

        q, k, v = (
            self._to_heads(proj(hidden_states), batch_size, seq_len)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        # Dropout inside attention is only active while training.
        dropout_p = self.attention_dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=False,
        )

        # Merge heads back: (B, heads, L, head_dim) -> (B, L, hidden_size).
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        return self.o_proj(merged)
|
|
|
|
|
|
|
|
class QFormerMLP(nn.Module):
    """Gated feed-forward block computing ``down(GELU(gate(x)) * up(x))``.

    Args:
        hidden_size: input/output width.
        intermediate_size: width of the gated hidden layer.
        mlp_bias: whether the three projections carry a bias term.
    """

    def __init__(self, hidden_size, intermediate_size, mlp_bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Creation order (gate, up, down) kept stable for reproducible init.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)
        self.act_fn = nn.GELU()

    def forward(self, x):
        """Map ``(B, seq_len, hidden_size)`` to ``(B, seq_len, hidden_size)``."""
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
|
|
|
|
class QFormerLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each with a residual.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads.
        intermediate_size: hidden width of the MLP.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.attention = QFormerAttention(hidden_size, num_heads)
        self.attention_norm = nn.LayerNorm(hidden_size)

        self.mlp = QFormerMLP(hidden_size, intermediate_size)
        self.mlp_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        """Apply one layer to ``(B, seq_len, hidden_size)`` and return the same shape.

        ``attention_mask`` is passed through to the attention sub-module.
        """
        # Attention sub-block: x + Attn(LN(x))
        attn_out = self.attention(self.attention_norm(hidden_states), attention_mask)
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block: y + MLP(LN(y))
        mlp_out = self.mlp(self.mlp_norm(after_attn))
        return after_attn + mlp_out
|
|
|
|
|
|
|
|
class SimplifiedQFormer(nn.Module):
    """
    Simplified QFormer with full self-attention between queries and inputs.

    This replaces Blip2QFormerModel with a cleaner implementation: the query
    tokens and the input tokens are concatenated into one sequence, passed
    through a stack of pre-norm ``QFormerLayer``s without any attention mask
    (so every token attends to every other token), and only the query
    positions are returned after a final LayerNorm.

    Args:
        hidden_size: width shared by queries and inputs.
        num_heads: attention heads per layer.
        num_layers: number of stacked layers.
        intermediate_size: MLP hidden width; defaults to ``4 * hidden_size``.
    """

    def __init__(self, hidden_size, num_heads=8, num_layers=1, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers

        if intermediate_size is None:
            intermediate_size = hidden_size * 4

        self.layers = nn.ModuleList([
            QFormerLayer(hidden_size, num_heads, intermediate_size)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, query_embeds, encoder_hidden_states):
        """
        Args:
            query_embeds: (B, num_queries, hidden_size) - learnable queries.
                A batch dimension of 1 is broadcast to the batch size of
                ``encoder_hidden_states``.
            encoder_hidden_states: (B, num_tokens, hidden_size) - input features

        Returns:
            (B, num_queries, hidden_size) - output features
        """
        num_queries = query_embeds.shape[1]
        batch_size = encoder_hidden_states.shape[0]

        # torch.cat does not broadcast, so a shared (1, Q, H) query tensor must
        # be expanded explicitly (expand is a view, no copy).
        if query_embeds.shape[0] == 1 and batch_size != 1:
            query_embeds = query_embeds.expand(batch_size, -1, -1)

        # Joint sequence: queries first so they can be sliced back out below.
        hidden_states = torch.cat([query_embeds, encoder_hidden_states], dim=1)

        for layer in self.layers:
            hidden_states = layer(hidden_states)

        # Keep only the query positions; input positions are discarded.
        output = hidden_states[:, :num_queries, :]

        return self.norm(output)
|
|
|
|
|
|
|
|
|
|
|
class InterpolateDownsampler:
    """Resize a square grid of patch tokens with ``torch.nn.functional.interpolate``.

    The grid side is ``image_size // patch_size`` from ``config.vision_config``.
    It is scaled by ``Fraction(config.downsample_rate)``, and a non-integer
    result is truncated by ``int``.

    Args:
        config: object exposing ``vision_config.image_size``,
            ``vision_config.patch_size`` and ``downsample_rate``.
        mode: interpolation mode passed to ``interpolate`` (default ``"area"``).
    """

    def __init__(self, config, mode="area"):
        vision = config.vision_config
        self.orig_image_side = vision.image_size // vision.patch_size
        scale = Fraction(config.downsample_rate)
        self.new_image_side = int(self.orig_image_side * scale)
        self.mode = mode

    def __call__(self, image_features):
        """Map ``(B, orig_side**2, C)`` raster tokens to ``(B, new_side**2, C)``."""
        batch_size, _, dim = image_features.size()
        side = self.orig_image_side

        # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W) for spatial interpolation.
        channels_first = image_features.view(batch_size, side, side, dim).permute(0, 3, 1, 2)
        target = (self.new_image_side, self.new_image_side)
        resized = torch.nn.functional.interpolate(channels_first, size=target, mode=self.mode)

        # Back to channels-last raster order: (B, h, w, C) -> (B, h*w, C).
        return resized.permute(0, 2, 3, 1).flatten(1, 2)
|
|
|
|
|
|
|
|
class QFormerDownsampler(nn.Module):
    """Downsample image tokens with a global BLIP-2 QFormer plus an interpolation shortcut.

    ``query_side ** 2`` learnable queries are seeded with the interpolated
    (downsampled) image tokens, where
    ``query_side = image_side * Fraction(downsample_rate)``. They
    cross-attend to the full-resolution tokens (plus learned position
    embeddings). The interpolated tokens are added back to the QFormer output
    as a residual.

    Args:
        config: object exposing ``text_config.hidden_size``,
            ``vision_config.image_size``, ``vision_config.patch_size`` and
            ``downsample_rate``.

    Raises:
        ValueError: if ``image_side * downsample_rate`` is not an integer.
    """

    def __init__(self, config):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        self.interpolate = InterpolateDownsampler(config)

        configuration = Blip2QFormerConfig(hidden_size=llm_hidden_size,
                                           num_attention_heads=32,
                                           intermediate_size=4096,
                                           num_hidden_layers=1,
                                           encoder_hidden_size=llm_hidden_size,
                                           cross_attention_frequency=1,
                                           max_position_embeddings=2048,
                                           use_qformer_text_input=False,
                                           )
        self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        down = Fraction(config.downsample_rate)
        query_side = self.image_side * down
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, after which `.numerator` would silently give a wrong side.
        if query_side.denominator != 1:
            raise ValueError(
                "downsample_rate must make query_side an integer "
                f"(image_side={self.image_side}, downsample_rate={config.downsample_rate!r})"
            )
        self.query_side = query_side.numerator

        self.query_length = self.query_side ** 2
        # Small-variance init for the learnable embeddings.
        embed_std = 1 / math.sqrt(llm_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, llm_hidden_size) * embed_std)

        # One learned position embedding per full-resolution image token.
        self.image_positions = nn.Parameter(torch.randn(1, self.image_side ** 2, llm_hidden_size) * embed_std)

    def forward(self, image_features):
        """Downsample ``image_features``.

        Args:
            image_features: raster-ordered tokens, assumed
                ``(B, image_side**2, text hidden size)``. The token count is
                implied by the ``view`` in ``InterpolateDownsampler`` and the
                shape of ``image_positions``.

        Returns:
            QFormer output plus the interpolated tokens. Presumably
            ``(B, query_side**2, hidden)``; confirm against the transformers
            ``Blip2QFormerModel`` output.
        """
        interpolated = self.interpolate(image_features)
        query_output = self.qformer(
            query_embeds=self.query + interpolated,
            encoder_hidden_states=image_features + self.image_positions,
            return_dict=True,
        ).last_hidden_state

        # Residual around the QFormer keeps the plain interpolation as a baseline.
        return query_output + interpolated
|
|
|
|
|
class WindowQFormerDownsampler(nn.Module):
    """Downsample a square patch grid by running a QFormer on local windows.

    ``config.downsample_rate`` is read as an *unreduced* ratio ``"q/w"``:
    - The ``image_side x image_side`` patch grid is split into non-overlapping
      ``w x w`` windows.
    - Each window is summarised by ``q x q`` learnable queries, seeded with the
      matching window of the interpolated grid.
    - The per-window outputs are stitched back into raster order, the
      interpolated tokens are added as a residual, and the result is projected
      from the vision hidden size to the text hidden size.

    Args:
        config: object exposing ``text_config.hidden_size``,
            ``vision_config.hidden_size``, ``vision_config.image_size``,
            ``vision_config.patch_size`` and ``downsample_rate``.
        use_simplified_qformer: use ``SimplifiedQFormer`` (joint self-attention)
            instead of the transformers ``Blip2QFormerModel``.

    Raises:
        ValueError: if the ratio sides are not positive, or ``window_side``
            does not divide ``image_side``.
    """

    def __init__(self, config, use_simplified_qformer=False):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        self.interpolate = InterpolateDownsampler(config)
        self.use_simplified_qformer = use_simplified_qformer

        # Validate the geometry up front; `_win` cannot tile an image side that
        # is not a multiple of the window side.
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # Numerator and denominator carry meaning ("2/4" != "1/2"), so the ratio
        # must not be reduced. This used to rely on the private
        # `Fraction(..., _normalize=False)` argument, which CPython 3.12 removed.
        self.query_side, self.window_side = self._parse_downsample_rate(config.downsample_rate)
        if self.query_side <= 0 or self.window_side <= 0:
            raise ValueError(
                f"downsample_rate must be a positive ratio q/w, got {config.downsample_rate!r}"
            )
        if self.image_side % self.window_side != 0:
            raise ValueError(
                f"window side {self.window_side} must divide image side {self.image_side} "
                f"(downsample_rate={config.downsample_rate!r})"
            )

        if use_simplified_qformer:
            # Queries and window tokens share one self-attention sequence.
            self.qformer = SimplifiedQFormer(
                hidden_size=vision_hidden_size,
                num_heads=18,
                num_layers=1,
                intermediate_size=4096
            )
        else:
            configuration = Blip2QFormerConfig(
                hidden_size=vision_hidden_size,
                num_attention_heads=16,
                intermediate_size=4096,
                num_hidden_layers=1,
                encoder_hidden_size=vision_hidden_size,
                cross_attention_frequency=1,
                max_position_embeddings=2048,
                use_qformer_text_input=False,
            )
            self.qformer = Blip2QFormerModel(configuration)

        self.query_length = self.query_side ** 2
        # Small-variance init for the learnable embeddings.
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)

        # Position embeddings are per position *within* a window, shared by all windows.
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=False)

    @staticmethod
    def _parse_downsample_rate(rate):
        """Return ``(query_side, window_side)`` from ``rate`` without reducing it.

        A string ``"q/w"`` is split on ``/``, so ``"2/4"`` gives ``(2, 4)``
        rather than ``(1, 2)``. Any other value (int, float, Fraction, or a
        string without ``/``) falls back to
        ``Fraction(rate).as_integer_ratio()``, which is in lowest terms.
        """
        if isinstance(rate, str):
            numerator, sep, denominator = rate.strip().partition("/")
            if sep:
                return int(numerator), int(denominator)
        return Fraction(rate).as_integer_ratio()

    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win
        windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)     # -> (B, n_row, n_col, win_row, win_col, C)
            .flatten(0, 2)       # one batch entry per window
            .flatten(1, 2)       # raster order inside the window
        )

    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster; inverse of `_win`.
        """
        Bnn, _, C = xw.shape
        if Bnn % (n * n) != 0:
            raise ValueError(f"leading dim {Bnn} is not a multiple of n*n={n * n}")
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)     # -> (B, n_row, win_row, n_col, win_col, C)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )

    def forward(self, image_features):
        """Downsample raster-ordered image tokens window by window.

        Args:
            image_features: ``(B, image_side**2, C)`` tokens; ``C`` is assumed
                to be the vision hidden size, which the window position
                embeddings and ``out_linear`` expect.

        Returns:
            ``(B, (image_side * query_side // window_side)**2, text hidden size)``.

        Raises:
            ValueError: if the token count is not ``image_side**2``.
        """
        _, HW, _ = image_features.shape
        if HW != self.image_side * self.image_side:
            raise ValueError(
                f"expected {self.image_side * self.image_side} image tokens, got {HW}"
            )
        n = self.image_side // self.window_side

        # Full-resolution windows: keys/values for each window's queries.
        enc = self._win(image_features, self.image_side, self.window_side)

        # Interpolated grid has side n*query_side, since window_side divides
        # image_side; split it into matching query_side x query_side windows.
        interpolated = self.interpolate(image_features)
        new_side = n * self.query_side
        interpolated_w = self._win(interpolated, new_side, self.query_side)

        queries = self.query + interpolated_w
        window_tokens = enc + self.image_positions

        if self.use_simplified_qformer:
            out_w = self.qformer(
                query_embeds=queries,
                encoder_hidden_states=window_tokens
            )
        else:
            out_w = self.qformer(
                query_embeds=queries,
                encoder_hidden_states=window_tokens,
                return_dict=True,
            ).last_hidden_state

        out = self._unwin(out_w, n=n, win=self.query_side)

        # Residual with the plain interpolation, then project to the LLM width.
        return self.out_linear(out + interpolated)
|
|
|