# ConceptFrameMet / modeling_conceptframemet.py
# Uploaded by nixie1981 with huggingface_hub (commit b6384b1, verified)
"""
Adaptive Source QA MelBERT with Configurable Blending Strategies
This model provides configurable approaches to incorporating source domain information:
FLAGS:
1. --source_blend_mode: 'additive' or 'replacement' (default: 'replacement')
- additive: enhanced = target + alpha * source (keeps target strength)
- replacement: blended = conf * source + (1-conf) * target (original approach)
2. --source_use_mode: 'metaphor_only' or 'all' (default: 'all')
- metaphor_only: Only use source for samples with high metaphor probability
- all: Use source for all samples
3. --source_alpha: float (default: 0.3) - scaling factor for additive mode
4. --metaphor_threshold: float (default: 0.5) - threshold for metaphor-only mode
Architecture:
- CONTEXT: target_word in full sentence β†’ encoder 1 β†’ target_context_embedding
- SOURCE: [SEP] sentence [SEP] target [SEP] β†’ QA model β†’ predict source + confidence
- ISOLATED: isolated target β†’ encoder 2 β†’ target_embedding
- BLEND: Configurable (additive or replacement)
- FILTER: Configurable (metaphor-only or all)
- MIP: [enhanced_embedding, target_context_embedding]
- SPV: [pooled, enhanced_embedding] or [pooled, target_context_embedding]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveSourceQAMelBert(nn.Module):
    """MelBERT with configurable source-domain blending strategies.

    One shared ``encoder`` is run on the full sentence (context view), on the
    isolated target word, and - when a source QA model is attached - on the
    text of the predicted source-domain label. The source and isolated-target
    embeddings are blended according to ``source_blend_mode``, optionally
    gated per sample by ``source_use_mode``, and fed to the SPV / MIP heads.
    """

    # Supported values for the configuration flags.
    _BLEND_MODES = ('additive', 'replacement')
    _USE_MODES = ('metaphor_only', 'all')

    def __init__(self, args, Model, config, Source_QA_Model,
                 source_qa_tokenizer, melbert_tokenizer, num_labels=2):
        """Initialize the model with configurable flags.

        Args:
            args: Configuration namespace. Read directly: ``drop_ratio``,
                ``classifier_hidden``, ``small_mean``, ``spv_isolate``,
                ``max_seq_length``. Read with defaults:
                - source_blend_mode: 'additive' or 'replacement' (default 'replacement')
                - source_use_mode: 'metaphor_only' or 'all' (default 'all')
                - source_alpha: scaling factor for additive mode (default 0.3)
                - metaphor_threshold: threshold for metaphor-only mode (default 0.5)
                - unfreeze_source_qa: make the QA model trainable (default False)
            Model: MelBERT encoder (RoBERTa/BERT); its outputs are indexed as
                ``[0]`` (token states) and ``[1]`` (pooled output).
            config: Model configuration (``hidden_size``, ``initializer_range``).
            Source_QA_Model: QA-style model predicting the source domain, or None.
            source_qa_tokenizer: Tokenizer for the QA model.
            melbert_tokenizer: Tokenizer for the MelBERT encoder.
            num_labels: Number of metaphor classes (2: literal/metaphorical).

        Raises:
            ValueError: if ``source_blend_mode`` or ``source_use_mode`` is not a
                supported value (previously such typos silently fell back to
                replacement / all).
        """
        super().__init__()

        # Configuration flags with defaults; validated up-front so that a
        # misconfiguration fails before any weights are touched.
        self.source_blend_mode = getattr(args, 'source_blend_mode', 'replacement')
        self.source_use_mode = getattr(args, 'source_use_mode', 'all')
        self.source_alpha = getattr(args, 'source_alpha', 0.3)
        self.metaphor_threshold = getattr(args, 'metaphor_threshold', 0.5)
        if self.source_blend_mode not in self._BLEND_MODES:
            raise ValueError(
                f"source_blend_mode must be one of {self._BLEND_MODES}, "
                f"got {self.source_blend_mode!r}"
            )
        if self.source_use_mode not in self._USE_MODES:
            raise ValueError(
                f"source_use_mode must be one of {self._USE_MODES}, "
                f"got {self.source_use_mode!r}"
            )

        self.num_labels = num_labels
        self.encoder = Model
        # Training used type_vocab_size=4, so the encoder's token-type table
        # is resized to match before any checkpoint is loaded.
        self._resize_token_type_embeddings(config, type_vocab_size=4)

        self.source_qa_model = Source_QA_Model
        self.source_qa_tokenizer = source_qa_tokenizer
        self.melbert_tokenizer = melbert_tokenizer
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.args = args

        # Freeze or unfreeze the source QA model (only if it exists).
        if self.source_qa_model is not None:
            trainable = bool(getattr(args, 'unfreeze_source_qa', False))
            for param in self.source_qa_model.parameters():
                param.requires_grad = trainable

        # id -> source-domain label text; empty if no label file was found.
        self.source_id2label = self._load_source_labels()

        # SPV and MIP linear layers.
        self.SPV_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.MIP_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.classifier = nn.Linear(args.classifier_hidden * 2, num_labels)
        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self._print_config()

    def _resize_token_type_embeddings(self, config, type_vocab_size=4):
        """Give the encoder exactly ``type_vocab_size`` token-type embeddings.

        Does nothing if the encoder has no ``embeddings.token_type_embeddings``
        or the table already has the requested size. Otherwise row 0 is copied
        from the old table and the remaining rows are drawn from
        N(0, config.initializer_range). The new table is created on the old
        table's device and dtype so it stays consistent with the encoder.
        """
        embeddings = getattr(self.encoder, 'embeddings', None)
        old_table = getattr(embeddings, 'token_type_embeddings', None)
        if old_table is None or old_table.weight.shape[0] == type_vocab_size:
            return
        new_table = nn.Embedding(
            type_vocab_size,
            old_table.embedding_dim,
            device=old_table.weight.device,
            dtype=old_table.weight.dtype,
        )
        new_table.weight.data[0] = old_table.weight.data[0]
        new_table.weight.data[1:].normal_(mean=0.0, std=config.initializer_range)
        embeddings.token_type_embeddings = new_table
        if hasattr(self.encoder, 'config'):
            self.encoder.config.type_vocab_size = type_vocab_size

    @staticmethod
    def _load_source_labels():
        """Return an ``{id: label}`` mapping from the first readable label file.

        Candidates, in order: ``source_labels.json`` in the current working
        directory, ``source_finder/source_labels.json`` under it, and
        ``source_labels.json`` next to this module. The file is expected to be
        a JSON object mapping label -> id; it is inverted. A missing file is
        skipped silently; an unreadable or malformed one is reported and
        skipped. Returns an empty dict (after a warning) if nothing loads.
        """
        import json
        import os

        candidate_paths = [
            'source_labels.json',                # current working directory
            'source_finder/source_labels.json',  # original location
        ]
        module_file = globals().get('__file__')
        if module_file:
            # Next to this file.
            candidate_paths.append(
                os.path.join(os.path.dirname(module_file), 'source_labels.json'))

        for path in candidate_paths:
            try:
                with open(path, 'r', encoding='utf-8') as handle:
                    label2id = json.load(handle)
                id2label = {idx: label for label, idx in label2id.items()}
            except FileNotFoundError:
                continue
            except (OSError, ValueError, AttributeError, TypeError) as err:
                # ValueError covers json.JSONDecodeError; Attribute/TypeError
                # cover a JSON document that is not a label -> id object.
                print(f"❌ Warning: Could not load source labels from {path}: {err}")
                continue
            print(f"✓ Loaded {len(id2label)} source domain labels from {path}")
            return id2label

        print("❌ Warning: Could not load source labels from any location")
        return {}

    def _print_config(self):
        """Print a human-readable summary of the blending configuration."""
        rule = '=' * 80
        print(f"\n{rule}")
        print("✓ AdaptiveSourceQAMelBert initialized")
        print(f" - Blend Mode: {self.source_blend_mode.upper()}")
        if self.source_blend_mode == 'additive':
            print(f" → enhanced = target + {self.source_alpha} * conf * source")
        else:
            print(" → blended = conf * source + (1-conf) * target")
        print(f" - Use Mode: {self.source_use_mode.upper()}")
        if self.source_use_mode == 'metaphor_only':
            print(f" → Only use source when metaphor_score > {self.metaphor_threshold}")
        else:
            print(" → Use source for all samples")
        print(f"{rule}\n")

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range),
        LayerNorm to identity, and Linear biases to zero."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _pool_masked(self, sequence_output, mask, dropout=None):
        """Reduce the token states selected by ``mask`` to one vector per sample.

        ``sequence_output`` is multiplied by ``mask`` (broadcast over the last
        dim), optionally passed through ``dropout``, then reduced over dim 1.
        With ``args.small_mean`` this is a plain mean over all positions;
        otherwise it is the sum divided by the number of selected positions.
        That count is clamped to >= 1 so a sample with an empty mask yields a
        zero vector instead of NaN.
        """
        masked = sequence_output * mask.unsqueeze(2)
        if dropout is not None:
            masked = dropout(masked)
        if self.args.small_mean:
            return masked.mean(1)
        return masked.sum(dim=1) / mask.sum(-1, keepdim=True).clamp(min=1)

    def predict_source_and_embeddings(self, input_ids, target_mask, attention_mask,
                                      input_ids_2, target_mask_2, attention_mask_2):
        """Predict the source domain and return source/target embeddings.

        Args:
            input_ids, target_mask: full-sentence ids and 0/1 target mask; used
                only to reconstruct the sentence and target-word text for the
                QA model.
            attention_mask: not used here; kept for interface compatibility.
            input_ids_2, target_mask_2, attention_mask_2: isolated-target view
                fed to the shared encoder.

        Returns:
            source_embeddings: [batch_size, hidden_size]
            target_embeddings: [batch_size, hidden_size] (isolated target)
            confidences: [batch_size] - probability of the predicted source
                label, or a constant 0.5 when no QA model is attached.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        if self.source_qa_model is None:
            # No QA model: reuse the isolated-target embedding as the "source"
            # with a fixed 0.5 confidence, so replacement blending returns the
            # target embedding unchanged.
            target_outputs_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2)
            target_embeddings_2 = self._pool_masked(target_outputs_2[0], target_mask_2)
            confidences = target_embeddings_2.new_full((batch_size,), 0.5)
            return target_embeddings_2, target_embeddings_2, confidences

        # 1. Recover sentence and target-word text for the QA model.
        sentences = []
        target_words = []
        for row_ids, row_mask in zip(input_ids, target_mask):
            sentences.append(
                self.melbert_tokenizer.decode(row_ids, skip_special_tokens=True))
            target_positions = row_mask.nonzero(as_tuple=True)[0]
            if target_positions.numel() > 0:
                target_words.append(self.melbert_tokenizer.decode(
                    row_ids[target_positions], skip_special_tokens=True))
            else:
                target_words.append("unknown")

        # 2. Format the (sentence, target) QA input.
        qa_inputs = self.source_qa_tokenizer(
            sentences,
            target_words,
            max_length=self.args.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        qa_inputs = {k: v.to(device) for k, v in qa_inputs.items()}
        # A model exposing `frame_finder` (FrameAwareSourcePredictor) also takes
        # frame inputs; here they are the same as the source inputs.
        if hasattr(self.source_qa_model, 'frame_finder'):
            qa_inputs['frame_input_ids'] = qa_inputs['input_ids']
            qa_inputs['frame_attention_mask'] = qa_inputs['attention_mask']

        # 3. Predict the source domain with a confidence score. Autograd stays
        # on so `unfreeze_source_qa=True` can train the QA model through
        # `confidences`; a frozen QA model has no grad-requiring parameters,
        # so nothing is recorded in that case.
        qa_outputs = self.source_qa_model(**qa_inputs)
        source_logits = qa_outputs.logits
        source_probs = torch.softmax(source_logits, dim=-1)
        predicted_source_ids = torch.argmax(source_logits, dim=-1)
        confidences = source_probs.gather(1, predicted_source_ids.unsqueeze(1)).squeeze(1)

        # 4. Encode the predicted source-label text with the shared encoder.
        predicted_sources = [self.source_id2label.get(sid, "UNKNOWN")
                             for sid in predicted_source_ids.tolist()]
        source_inputs = self.melbert_tokenizer(
            predicted_sources,
            max_length=self.args.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        source_inputs = {k: v.to(device) for k, v in source_inputs.items()}
        # Pool over every non-pad position (including any special tokens the
        # tokenizer inserts).
        source_mask = (source_inputs['input_ids'] != self.melbert_tokenizer.pad_token_id).float()
        source_outputs = self.encoder(
            source_inputs['input_ids'],
            attention_mask=source_inputs['attention_mask'],
        )
        source_embeddings = self._pool_masked(source_outputs[0], source_mask)

        # 5. Encode the original isolated target words.
        target_outputs_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2)
        target_embeddings_2 = self._pool_masked(target_outputs_2[0], target_mask_2)

        return source_embeddings, target_embeddings_2, confidences

    def blend_embeddings(self, source_embeddings, target_embeddings, confidences):
        """Blend source and target embeddings according to ``source_blend_mode``.

        additive:    target + source_alpha * conf * source
        replacement: conf * source + (1 - conf) * target

        Args:
            source_embeddings: [batch_size, hidden_size]
            target_embeddings: [batch_size, hidden_size]
            confidences: [batch_size]

        Returns:
            blended_embeddings: [batch_size, hidden_size]
        """
        conf = confidences.unsqueeze(1)
        if self.source_blend_mode == 'additive':
            # Keeps the target's strength and adds the source as an enhancement.
            return target_embeddings + self.source_alpha * conf * source_embeddings
        return conf * source_embeddings + (1 - conf) * target_embeddings

    def forward(
        self,
        input_ids,
        input_ids_2,
        target_mask,
        target_mask_2,
        attention_mask_2,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
        input_with_mask_ids=None,
    ):
        """Forward pass with configurable source blending.

        Returns the NLL loss (scalar) when ``labels`` is given, otherwise the
        log-probabilities of shape [batch_size, num_labels].
        ``input_with_mask_ids`` is accepted for interface compatibility and is
        not used.
        """
        # ===== Pass 1: target in context =====
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        # Dropout on the target states is applied before dropout on the pooled
        # output, matching the original order of random draws.
        target_output = self._pool_masked(outputs[0], target_mask, dropout=self.dropout)
        pooled_output = self.dropout(outputs[1])

        # ===== Source and isolated-target embeddings =====
        source_embeddings, target_embeddings_2, confidences = self.predict_source_and_embeddings(
            input_ids, target_mask, attention_mask,
            input_ids_2, target_mask_2, attention_mask_2,
        )

        # ===== Metaphor-only filtering (if enabled) =====
        if self.source_use_mode == 'metaphor_only':
            # Preliminary metaphor score from the SPV features alone (duplicated
            # to fill the classifier's input). It only feeds a hard threshold,
            # which has no gradient, so autograd is skipped.
            with torch.no_grad():
                prelim_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
                prelim_logits = self.classifier(torch.cat([prelim_hidden, prelim_hidden], dim=1))
                metaphor_scores = torch.exp(self.logsoftmax(prelim_logits))[:, 1]
            use_source_mask = (
                (metaphor_scores > self.metaphor_threshold)
                .to(source_embeddings.dtype)
                .unsqueeze(1)
            )
        else:
            use_source_mask = source_embeddings.new_ones(source_embeddings.size(0), 1)

        # ===== Blend, then fall back to the plain target where gated off =====
        blended_embedding = self.blend_embeddings(source_embeddings, target_embeddings_2, confidences)
        final_embedding = use_source_mask * blended_embedding + (1 - use_source_mask) * target_embeddings_2
        final_embedding = self.dropout(final_embedding)

        # ===== SPV and MIP =====
        if self.args.spv_isolate:
            SPV_hidden = self.SPV_linear(torch.cat([pooled_output, final_embedding], dim=1))
        else:
            SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
        MIP_hidden = self.MIP_linear(torch.cat([final_embedding, target_output], dim=1))

        # Final classification.
        logits = self.classifier(self.dropout(torch.cat([SPV_hidden, MIP_hidden], dim=1)))
        logits = self.logsoftmax(logits)

        if labels is not None:
            loss_fct = nn.NLLLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits