"""Simple inference script for ConceptFrameMet model."""
import argparse
import json

import torch
from transformers import RobertaTokenizer, RobertaModel


def load_model(model_path):
    """Load the ConceptFrameMet tokenizer, raw weight dict, and config.

    Args:
        model_path: Directory containing the saved tokenizer files,
            ``pytorch_model.bin``, and ``config.json``.

    Returns:
        Tuple of (tokenizer, model_weights, config) where ``model_weights``
        is the raw state dict and ``config`` the parsed ``config.json``.
    """
    # Load tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(model_path)
    # weights_only=True: the checkpoint is loaded as a plain state dict.
    # Without this flag torch.load unpickles arbitrary objects, which is a
    # code-execution risk for untrusted checkpoint files.
    model_weights = torch.load(
        f"{model_path}/pytorch_model.bin",
        map_location='cpu',
        weights_only=True,
    )
    # Load config
    with open(f"{model_path}/config.json", 'r') as f:
        config = json.load(f)
    print(f"✓ Model loaded from {model_path}")
    print(f" Model type: {config.get('model_type', 'roberta')}")
    return tokenizer, model_weights, config


def _find_target_positions(tokenizer, sentence, target_word):
    """Locate the token positions of ``target_word`` inside ``sentence``.

    Positions are offset by +1 to account for the <s> (CLS) token that the
    tokenizer prepends when encoding the full sentence.

    Returns:
        List of token indices, or an empty list if the word is not found.
    """
    sentence_tokens = tokenizer.tokenize(sentence)
    # RoBERTa's byte-level BPE encodes a leading space into the token
    # ("Ġword" vs "word"), so tokens produced from the bare word only match
    # a word at sentence start.  Try the bare form first (original
    # behavior), then the leading-space variant for mid-sentence matches.
    for variant in (target_word, " " + target_word):
        target_tokens = tokenizer.tokenize(variant)
        if not target_tokens:
            continue
        for i in range(len(sentence_tokens) - len(target_tokens) + 1):
            if sentence_tokens[i:i + len(target_tokens)] == target_tokens:
                # +1 for CLS token
                return list(range(i + 1, i + 1 + len(target_tokens)))
    return []


def predict_metaphor(sentence, target_word, model_path, device='cpu'):
    """
    Predict if a target word is metaphorical in the given sentence.

    Args:
        sentence: Input sentence
        target_word: Target word to analyze
        model_path: Path to model directory
        device: Device to run on ('cpu' or 'cuda').  NOTE(review): currently
            unused — full inference is not implemented yet (see below).

    Returns:
        Dictionary with predictions (or an "error" key if the target word
        cannot be located in the tokenized sentence).
    """
    tokenizer, model_weights, config = load_model(model_path)

    # Tokenize input
    inputs = tokenizer(
        sentence,
        max_length=150,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Find target word positions (already offset for the CLS token)
    target_positions = _find_target_positions(tokenizer, sentence, target_word)

    if not target_positions:
        return {
            "error": "Target word not found in sentence",
            "sentence": sentence,
            "target_word": target_word
        }

    # Create target mask — 1.0 at each target token position, 0.0 elsewhere.
    # Positions beyond max_length (truncated input) are skipped.
    target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
    for pos in target_positions:
        if pos < target_mask.size(1):
            target_mask[0, pos] = 1.0

    print(f"\n{'='*60}")
    print(f"Sentence: {sentence}")
    print(f"Target: {target_word}")
    print(f"Target positions: {target_positions}")
    print(f"{'='*60}\n")

    # For now, return basic info
    # Full inference requires loading the complete model architecture
    return {
        "sentence": sentence,
        "target_word": target_word,
        "target_positions": target_positions,
        "message": "Model loaded successfully. Full inference requires frame and source models.",
        "note": "This is a placeholder. Integrate with modeling_conceptframemet.py for full predictions."
    }


def main():
    """Parse command-line arguments and run a single prediction."""
    parser = argparse.ArgumentParser(description='ConceptFrameMet Inference')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to model directory')
    parser.add_argument('--sentence', type=str, required=True,
                        help='Input sentence')
    parser.add_argument('--target', type=str, required=True,
                        help='Target word')
    parser.add_argument('--device', type=str, default='cpu',
                        choices=['cpu', 'cuda'],
                        help='Device to use')
    args = parser.parse_args()

    result = predict_metaphor(
        sentence=args.sentence,
        target_word=args.target,
        model_path=args.model_path,
        device=args.device
    )

    print("\nResult:")
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()