""" Complete working script to load ConceptFrameMet from HuggingFace with ALL weights. This properly reconstructs the source_qa_model from checkpoint weights. """ from huggingface_hub import hf_hub_download import torch import torch.nn as nn from transformers import RobertaModel, RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig import sys import os # Download files print("Downloading from HuggingFace...") weights_path = hf_hub_download("nixie1981/ConceptFrameMet", "pytorch_model.bin") labels_path = hf_hub_download("nixie1981/ConceptFrameMet", "source_labels.json") # Load checkpoint print("Loading checkpoint...") state_dict = torch.load(weights_path, map_location='cpu') print(f"Checkpoint has {len(state_dict)} keys") # Check what's in the checkpoint has_source_qa = any(k.startswith('source_qa_model.') for k in state_dict.keys()) print(f"Has source_qa_model weights: {has_source_qa}") if has_source_qa: # Count source_qa_model keys source_keys = [k for k in state_dict.keys() if k.startswith('source_qa_model.')] print(f"Source QA model has {len(source_keys)} keys") # Extract source_qa_model architecture from keys # Looking for: source_qa_model.roberta.*, source_qa_model.frame_finder.*, source_qa_model.source_classifier.* has_frame_finder = any('frame_finder' in k for k in source_keys) has_source_classifier = any('source_classifier' in k for k in source_keys) print(f" - Has frame_finder: {has_frame_finder}") print(f" - Has source_classifier: {has_source_classifier}") if has_frame_finder and has_source_classifier: print("\nThis is a TrueMultiTaskModel (frame + source)!") print("Creating source_qa_model structure...") # Get num_frames and num_sources from checkpoint frame_weight_key = 'source_qa_model.frame_finder.classifier.out_proj.weight' source_weight_key = 'source_qa_model.source_classifier.weight' num_frames = state_dict[frame_weight_key].shape[0] if frame_weight_key in state_dict else None num_sources = state_dict[source_weight_key].shape[0] if source_weight_key in state_dict else None print(f" - num_frames: {num_frames}") print(f" - num_sources: {num_sources}") if num_frames and num_sources: # CREATE the source_qa_model structure! config = RobertaConfig.from_pretrained('roberta-base') # Check actual source_classifier shape from checkpoint source_classifier_weight = state_dict.get('source_qa_model.source_classifier.weight') source_classifier_input_size = source_classifier_weight.shape[1] if source_classifier_weight is not None else None print(f" - source_classifier input size: {source_classifier_input_size}") class TrueMultiTaskModel(nn.Module): def __init__(self, config, num_frames, num_sources, source_input_size): super().__init__() self.config = config self.num_frames = num_frames self.num_sources = num_sources self.roberta = RobertaModel(config) self.frame_finder = RobertaForSequenceClassification(config) self.frame_finder.classifier = nn.Linear(config.hidden_size, num_frames) # Source classifier - use actual size from checkpoint self.dropout = nn.Dropout(config.hidden_dropout_prob) self.source_classifier = nn.Linear(source_input_size, num_sources) def forward(self, input_ids=None, attention_mask=None, frame_input_ids=None, frame_attention_mask=None, **kwargs): # Frame prediction frame_outputs = self.frame_finder(input_ids=frame_input_ids, attention_mask=frame_attention_mask) frame_logits = frame_outputs.logits # Source prediction if input_ids is not None: source_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask) pooled_output = source_outputs.pooler_output combined = torch.cat([pooled_output, frame_logits], dim=1) combined = self.dropout(combined) logits = self.source_classifier(combined) class Output: pass output = Output() output.logits = logits return output class Output: pass output = Output() output.logits = frame_logits return output # Create and load source_qa_model = TrueMultiTaskModel(config, num_frames, num_sources, source_classifier_input_size) # Extract source_qa_model weights source_state_dict = {} for k, v in state_dict.items(): if k.startswith('source_qa_model.'): new_key = k.replace('source_qa_model.', '') source_state_dict[new_key] = v # Load weights missing, unexpected = source_qa_model.load_state_dict(source_state_dict, strict=False) print(f"\nLoaded source_qa_model: missing={len(missing)}, unexpected={len(unexpected)}") print("\n✅ SOURCE_QA_MODEL CREATED AND LOADED!") print("Now the full model will work correctly!")