# ConceptFrameMet / inference.py
# Uploaded via huggingface_hub (commit 1b12abd, verified)
"""
Simple inference script for ConceptFrameMet model
"""
import torch
from transformers import RobertaTokenizer, RobertaModel
import json
import argparse
def load_model(model_path):
    """Load the ConceptFrameMet tokenizer, raw weights, and config.

    Args:
        model_path: Path to a model directory containing the tokenizer
            files, ``pytorch_model.bin``, and ``config.json``.

    Returns:
        Tuple of (tokenizer, model_weights, config) where
        ``model_weights`` is the raw state dict from ``pytorch_model.bin``
        and ``config`` is the parsed ``config.json`` dict.
    """
    # Load tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(model_path)
    # Load model weights on CPU. weights_only=True restricts the pickle
    # payload to tensors/primitives — safe for a plain state dict and
    # avoids arbitrary-code execution from an untrusted checkpoint.
    model_weights = torch.load(
        f"{model_path}/pytorch_model.bin",
        map_location='cpu',
        weights_only=True,
    )
    # Load config (JSON is UTF-8 by spec; be explicit about the encoding)
    with open(f"{model_path}/config.json", 'r', encoding='utf-8') as f:
        config = json.load(f)
    print(f"✓ Model loaded from {model_path}")
    print(f"  Model type: {config.get('model_type', 'roberta')}")
    return tokenizer, model_weights, config
def predict_metaphor(sentence, target_word, model_path, device='cpu'):
    """
    Predict if a target word is metaphorical in the given sentence.

    Args:
        sentence: Input sentence.
        target_word: Target word to analyze.
        model_path: Path to model directory.
        device: Device to run on ('cpu' or 'cuda'). Currently accepted
            for interface compatibility but unused — this placeholder
            never moves tensors off CPU.

    Returns:
        Dictionary with the target token positions (or an "error" key
        when the target word cannot be located in the sentence).
    """
    tokenizer, model_weights, config = load_model(model_path)

    # Tokenize input
    inputs = tokenizer(
        sentence,
        max_length=150,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Find target word positions. RoBERTa's BPE tokenizer marks a leading
    # space with a 'Ġ' prefix, so a word that is NOT sentence-initial
    # tokenizes as tokenize(" word"), not tokenize("word"). Try both
    # variants so mid-sentence targets are found.
    sentence_tokens = tokenizer.tokenize(sentence)
    target_positions = []
    for candidate in (tokenizer.tokenize(target_word),
                      tokenizer.tokenize(" " + target_word)):
        if not candidate:
            continue
        for i in range(len(sentence_tokens) - len(candidate) + 1):
            if sentence_tokens[i:i + len(candidate)] == candidate:
                # +1 for CLS token
                target_positions = list(range(i + 1, i + 1 + len(candidate)))
                break
        if target_positions:
            break

    if not target_positions:
        return {
            "error": "Target word not found in sentence",
            "sentence": sentence,
            "target_word": target_word
        }

    # Create target mask (guard positions that fell past the truncation
    # boundary of max_length=150)
    target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
    for pos in target_positions:
        if pos < target_mask.size(1):
            target_mask[0, pos] = 1.0

    print(f"\n{'='*60}")
    print(f"Sentence: {sentence}")
    print(f"Target: {target_word}")
    print(f"Target positions: {target_positions}")
    print(f"{'='*60}\n")

    # For now, return basic info
    # Full inference requires loading the complete model architecture
    return {
        "sentence": sentence,
        "target_word": target_word,
        "target_positions": target_positions,
        "message": "Model loaded successfully. Full inference requires frame and source models.",
        "note": "This is a placeholder. Integrate with modeling_conceptframemet.py for full predictions."
    }
def main():
    """CLI entry point: parse arguments, run prediction, print JSON result."""
    cli = argparse.ArgumentParser(description='ConceptFrameMet Inference')
    cli.add_argument('--model_path', type=str, required=True,
                     help='Path to model directory')
    cli.add_argument('--sentence', type=str, required=True,
                     help='Input sentence')
    cli.add_argument('--target', type=str, required=True,
                     help='Target word')
    cli.add_argument('--device', type=str, default='cpu',
                     choices=['cpu', 'cuda'], help='Device to use')
    opts = cli.parse_args()

    outcome = predict_metaphor(
        sentence=opts.sentence,
        target_word=opts.target,
        model_path=opts.model_path,
        device=opts.device,
    )

    print("\nResult:")
    print(json.dumps(outcome, indent=2))


if __name__ == "__main__":
    main()