import logging

import torch
import torch.nn as nn
from torchvision import models
from transformers import ConvNextModel

logger = logging.getLogger(__name__)


class FusionClassifier(nn.Module):
    """Classify images from concatenated EfficientNet-V2-S and ConvNeXt features.

    Two pretrained backbones each produce one pooled feature vector per
    sample; the vectors are concatenated and passed through an MLP head that
    emits ``num_classes`` logits. Both backbones are mostly frozen: only the
    later stages listed in ``__init__`` stay trainable.

    Args:
        num_classes: Number of output logits produced by the fusion head.
        convnext_model_name: Hugging Face model id passed to
            ``ConvNextModel.from_pretrained``.

    Attributes are documented inline where they are assigned.
    """

    def __init__(
        self,
        num_classes,
        convnext_model_name="facebook/convnext-small-224",
    ):
        super().__init__()
        logger.info("Initializing Fusion model...")

        # --- EfficientNet-V2-S branch -------------------------------------
        eff = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        # Freeze everything, then re-enable gradients for feature blocks 5-7.
        # NOTE(review): requires_grad=False does not stop BatchNorm running
        # statistics from updating while the module is in train() mode --
        # confirm whether the frozen blocks should also be kept in eval().
        eff.requires_grad_(False)
        for block_idx in (5, 6, 7):
            eff.features[block_idx].requires_grad_(True)

        # Only the feature extractor and pooling are kept; torchvision's own
        # classifier is dropped and replaced by ``fusion_head`` below.
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        # Width of the pooled EfficientNet vector, read from the input size of
        # the Linear layer in torchvision's original classifier.
        self.eff_out_dim = eff.classifier[1].in_features

        # --- ConvNeXt branch ----------------------------------------------
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        # Freeze everything, then re-enable the last two encoder stages and
        # the final layer norm.
        cnx.requires_grad_(False)
        cnx.encoder.stages[2].requires_grad_(True)
        cnx.encoder.stages[3].requires_grad_(True)
        cnx.layernorm.requires_grad_(True)

        self.cnx_backbone = cnx
        # Take the pooled width from the checkpoint's config instead of
        # hard-coding 768, so checkpoints with a different final stage width
        # build a correctly sized fusion head. For the default checkpoint this
        # is expected to remain 768.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        # --- Fusion head ---------------------------------------------------
        fused_dim = self.eff_out_dim + self.cnx_out_dim
        # Layer order is kept stable so existing state_dict keys
        # (``fusion_head.<index>.*``) continue to load.
        self.fusion_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

        logger.info("Fusion model initialized successfully.")

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Return class logits for a batch of paired image tensors.

        Args:
            pixel_values_eff: Image batch fed to the EfficientNet branch.
            pixel_values_cnx: Image batch fed to the ConvNeXt branch. Must
                have the same batch size as ``pixel_values_eff`` because the
                pooled features are concatenated along dim 1.

        Returns:
            The output of ``fusion_head`` on the concatenated features; its
            last dimension is ``num_classes``.
        """
        # EfficientNet: conv features -> pooling -> flatten all dims after batch.
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        # ConvNeXt: use the model's pooled output as the feature vector.
        cnx_out = self.cnx_backbone(
            pixel_values=pixel_values_cnx,
            return_dict=True,
        )
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)
        logits = self.fusion_head(fused)
        return logits


if __name__ == "__main__":
    # Smoke test: build the model and push one random batch through it.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    model = FusionClassifier(num_classes=6)

    eff_dummy = torch.randn(2, 3, 260, 260)
    cnx_dummy = torch.randn(2, 3, 224, 224)

    output = model(eff_dummy, cnx_dummy)
    print("Fusion output shape:", output.shape)