| """ |
| Complete working script to load ConceptFrameMet from HuggingFace with ALL weights. |
| This properly reconstructs the source_qa_model from checkpoint weights. |
| """ |
|
|
| from huggingface_hub import hf_hub_download |
| import torch |
| import torch.nn as nn |
| from transformers import RobertaModel, RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig |
| import sys |
| import os |
|
|
| |
| print("Downloading from HuggingFace...") |
| weights_path = hf_hub_download("nixie1981/ConceptFrameMet", "pytorch_model.bin") |
| labels_path = hf_hub_download("nixie1981/ConceptFrameMet", "source_labels.json") |
|
|
| |
| print("Loading checkpoint...") |
| state_dict = torch.load(weights_path, map_location='cpu') |
|
|
| print(f"Checkpoint has {len(state_dict)} keys") |
|
|
| |
| has_source_qa = any(k.startswith('source_qa_model.') for k in state_dict.keys()) |
| print(f"Has source_qa_model weights: {has_source_qa}") |
|
|
| if has_source_qa: |
| |
| source_keys = [k for k in state_dict.keys() if k.startswith('source_qa_model.')] |
| print(f"Source QA model has {len(source_keys)} keys") |
| |
| |
| |
| has_frame_finder = any('frame_finder' in k for k in source_keys) |
| has_source_classifier = any('source_classifier' in k for k in source_keys) |
| |
| print(f" - Has frame_finder: {has_frame_finder}") |
| print(f" - Has source_classifier: {has_source_classifier}") |
| |
| if has_frame_finder and has_source_classifier: |
| print("\nThis is a TrueMultiTaskModel (frame + source)!") |
| print("Creating source_qa_model structure...") |
| |
| |
| frame_weight_key = 'source_qa_model.frame_finder.classifier.out_proj.weight' |
| source_weight_key = 'source_qa_model.source_classifier.weight' |
| |
| num_frames = state_dict[frame_weight_key].shape[0] if frame_weight_key in state_dict else None |
| num_sources = state_dict[source_weight_key].shape[0] if source_weight_key in state_dict else None |
| |
| print(f" - num_frames: {num_frames}") |
| print(f" - num_sources: {num_sources}") |
| |
| if num_frames and num_sources: |
| |
| config = RobertaConfig.from_pretrained('roberta-base') |
| |
| |
| source_classifier_weight = state_dict.get('source_qa_model.source_classifier.weight') |
| source_classifier_input_size = source_classifier_weight.shape[1] if source_classifier_weight is not None else None |
| |
| print(f" - source_classifier input size: {source_classifier_input_size}") |
| |
| class TrueMultiTaskModel(nn.Module): |
| def __init__(self, config, num_frames, num_sources, source_input_size): |
| super().__init__() |
| self.config = config |
| self.num_frames = num_frames |
| self.num_sources = num_sources |
| |
| self.roberta = RobertaModel(config) |
| self.frame_finder = RobertaForSequenceClassification(config) |
| self.frame_finder.classifier = nn.Linear(config.hidden_size, num_frames) |
| |
| |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
| self.source_classifier = nn.Linear(source_input_size, num_sources) |
| |
| def forward(self, input_ids=None, attention_mask=None, |
| frame_input_ids=None, frame_attention_mask=None, **kwargs): |
| |
| frame_outputs = self.frame_finder(input_ids=frame_input_ids, |
| attention_mask=frame_attention_mask) |
| frame_logits = frame_outputs.logits |
| |
| |
| if input_ids is not None: |
| source_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask) |
| pooled_output = source_outputs.pooler_output |
| combined = torch.cat([pooled_output, frame_logits], dim=1) |
| combined = self.dropout(combined) |
| logits = self.source_classifier(combined) |
| |
| class Output: |
| pass |
| output = Output() |
| output.logits = logits |
| return output |
| |
| class Output: |
| pass |
| output = Output() |
| output.logits = frame_logits |
| return output |
| |
| |
| source_qa_model = TrueMultiTaskModel(config, num_frames, num_sources, source_classifier_input_size) |
| |
| |
| source_state_dict = {} |
| for k, v in state_dict.items(): |
| if k.startswith('source_qa_model.'): |
| new_key = k.replace('source_qa_model.', '') |
| source_state_dict[new_key] = v |
| |
| |
| missing, unexpected = source_qa_model.load_state_dict(source_state_dict, strict=False) |
| print(f"\nLoaded source_qa_model: missing={len(missing)}, unexpected={len(unexpected)}") |
| |
| print("\n✅ SOURCE_QA_MODEL CREATED AND LOADED!") |
| print("Now the full model will work correctly!") |
|
|