| | """ |
| | Adaptive Fusion Module for Hybrid Food Classifier |
| | Combines CNN and ViT features using cross-attention mechanism |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Tuple |
| |
|
class AdaptiveFusionModule(nn.Module):
    """Fuse CNN and ViT features with bidirectional cross-attention.

    Spatial path: the CNN feature map is flattened into a token sequence,
    both token sequences are projected, each attends to the other
    (CNN -> ViT and ViT -> CNN) with a residual connection, and the two
    results are concatenated along the sequence axis and passed through a
    self-attention layer.

    Global path: the two global vectors are concatenated and fed to (a) an
    MLP producing a fused vector and (b) a small gating network producing
    two softmax weights. The gated blend of the inputs is averaged with
    the MLP output and projected to ``hidden_dim``.
    """

    def __init__(
        self,
        feature_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.2,
        spatial_size: int = 7
    ):
        """Build the attention, projection and gating sub-networks.

        Args:
            feature_dim: Channel/embedding size shared by the CNN and ViT
                features; used as ``embed_dim`` of every attention layer.
            hidden_dim: Width of the hidden layers and size of the fused
                global output.
            num_heads: Number of heads in each attention layer.
            dropout: Dropout probability passed to every attention layer
                and used in every ``nn.Dropout``.
            spatial_size: Stored on the instance as ``self.spatial_size``;
                not read anywhere else in this class.
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.spatial_size = spatial_size

        # NOTE: submodule attribute names, their creation order and the flat
        # index layout of every nn.Sequential are kept stable on purpose:
        # they determine the state_dict keys and the seeded initialisation.

        # CNN tokens query the ViT tokens.
        self.cnn_to_vit_attention = self._make_attention(
            feature_dim, num_heads, dropout
        )
        # ViT tokens query the CNN tokens.
        self.vit_to_cnn_attention = self._make_attention(
            feature_dim, num_heads, dropout
        )
        # Self-attention over the concatenated CNN + ViT token sequence.
        self.self_attention = self._make_attention(
            feature_dim, num_heads, dropout
        )

        # Per-branch token projections applied before cross-attention.
        self.cnn_spatial_proj = nn.Sequential(
            *self._proj_layers(feature_dim, feature_dim, dropout)
        )
        self.vit_spatial_proj = nn.Sequential(
            *self._proj_layers(feature_dim, feature_dim, dropout)
        )

        # MLP over the concatenated global vectors: 2*feature_dim ->
        # hidden_dim -> feature_dim (one flat Sequential of eight layers).
        self.global_fusion = nn.Sequential(
            *self._proj_layers(feature_dim * 2, hidden_dim, dropout),
            *self._proj_layers(hidden_dim, feature_dim, dropout),
        )

        # Gating network: two softmax weights (CNN, ViT) per sample.
        self.adaptive_weight = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1)
        )

        # Final projection of the fused global vector to hidden_dim.
        self.final_proj = nn.Sequential(
            *self._proj_layers(feature_dim, hidden_dim, dropout)
        )

    @staticmethod
    def _make_attention(
        embed_dim: int, num_heads: int, dropout: float
    ) -> nn.MultiheadAttention:
        """Create a batch-first multi-head attention layer."""
        return nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

    @staticmethod
    def _proj_layers(
        in_dim: int, out_dim: int, dropout: float
    ) -> Tuple[nn.Module, ...]:
        """Return the Linear -> LayerNorm -> GELU -> Dropout layer stack."""
        return (
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        cnn_spatial: torch.Tensor,
        cnn_global: torch.Tensor,
        vit_spatial: torch.Tensor,
        vit_global: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse the spatial and global features of both branches.

        Args:
            cnn_spatial: CNN feature map [B, feature_dim, H, W]; everything
                from dim 2 onward is flattened into the token axis.
            cnn_global: CNN global features [B, feature_dim].
            vit_spatial: ViT patch features [B, num_patches, feature_dim].
            vit_global: ViT CLS token features [B, feature_dim].

        Returns:
            fused_spatial: Fused token sequence
                [B, H*W + num_patches, feature_dim] (CNN tokens first).
            fused_global: Fused global features [B, hidden_dim].
        """
        # [B, C, H, W] -> [B, H*W, C] so the CNN map is a token sequence.
        cnn_tokens = self.cnn_spatial_proj(
            cnn_spatial.flatten(2).transpose(1, 2)
        )
        vit_tokens = self.vit_spatial_proj(vit_spatial)

        # Bidirectional cross-attention; attention weights are discarded.
        cnn_attended, _ = self.cnn_to_vit_attention(
            query=cnn_tokens,
            key=vit_tokens,
            value=vit_tokens
        )
        vit_attended, _ = self.vit_to_cnn_attention(
            query=vit_tokens,
            key=cnn_tokens,
            value=cnn_tokens
        )

        # Residual connection per branch, then join along the token axis.
        combined_spatial = torch.cat(
            [cnn_attended + cnn_tokens, vit_attended + vit_tokens],
            dim=1
        )

        fused_spatial, _ = self.self_attention(
            query=combined_spatial,
            key=combined_spatial,
            value=combined_spatial
        )

        global_concat = torch.cat([cnn_global, vit_global], dim=-1)
        fused_global_base = self.global_fusion(global_concat)

        # Softmax output: the two weights are non-negative and sum to 1.
        weights = self.adaptive_weight(global_concat)
        cnn_weight = weights[:, 0:1]  # slices keep shape [B, 1] to broadcast
        vit_weight = weights[:, 1:2]

        # Average of the gated blend of the inputs and the MLP fusion.
        fused_global = (cnn_weight * cnn_global +
                        vit_weight * vit_global +
                        fused_global_base) / 2

        fused_global = self.final_proj(fused_global)

        return fused_spatial, fused_global

    def get_output_dim(self) -> int:
        """Return the size of the fused global output (``hidden_dim``)."""
        return self.hidden_dim