| """ |
| Simple inference script for ConceptFrameMet model |
| """ |
|
|
| import torch |
| from transformers import RobertaTokenizer, RobertaModel |
| import json |
| import argparse |
|
|
|
|
def load_model(model_path):
    """Load the ConceptFrameMet artifacts from *model_path*.

    Args:
        model_path: Directory containing the tokenizer files,
            ``pytorch_model.bin``, and ``config.json``.

    Returns:
        Tuple of (tokenizer, model_weights, config) where
        ``model_weights`` is the raw state dict loaded onto CPU and
        ``config`` is the parsed JSON config dict.
    """
    tokenizer = RobertaTokenizer.from_pretrained(model_path)

    # Raw checkpoint tensors only; mapped to CPU so loading works
    # on machines without a GPU.
    weights_path = f"{model_path}/pytorch_model.bin"
    model_weights = torch.load(weights_path, map_location='cpu')

    config_path = f"{model_path}/config.json"
    with open(config_path, 'r') as cfg_file:
        config = json.load(cfg_file)

    print(f"✓ Model loaded from {model_path}")
    print(f" Model type: {config.get('model_type', 'roberta')}")

    return tokenizer, model_weights, config
|
|
|
|
def predict_metaphor(sentence, target_word, model_path, device='cpu'):
    """
    Predict if a target word is metaphorical in the given sentence.

    Args:
        sentence: Input sentence
        target_word: Target word to analyze
        model_path: Path to model directory
        device: Device to run on ('cpu' or 'cuda') — currently unused;
            kept for forward compatibility with full model inference.

    Returns:
        Dictionary with the located target positions, or an ``"error"``
        key if the target word could not be aligned with the sentence.
    """
    tokenizer, model_weights, config = load_model(model_path)

    inputs = tokenizer(
        sentence,
        max_length=150,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    sentence_tokens = tokenizer.tokenize(sentence)

    # BUG FIX: RoBERTa's byte-level BPE prefixes word-initial tokens with a
    # space marker ("Ġ"), so tokenizing the bare target word only matches
    # when the target is the very first word of the sentence. Try both the
    # bare tokenization (sentence-initial) and the leading-space variant
    # (mid-sentence) so the subsequence search can succeed in either case.
    candidate_tokenizations = [
        tokenizer.tokenize(target_word),
        tokenizer.tokenize(' ' + target_word),
    ]

    target_positions = []
    for target_tokens in candidate_tokenizations:
        if not target_tokens:
            continue
        for i in range(len(sentence_tokens) - len(target_tokens) + 1):
            if sentence_tokens[i:i + len(target_tokens)] == target_tokens:
                # +1 accounts for the <s> (BOS) token the tokenizer
                # prepends in the encoded input_ids.
                target_positions = list(range(i + 1, i + 1 + len(target_tokens)))
                break
        if target_positions:
            break

    if not target_positions:
        return {
            "error": "Target word not found in sentence",
            "sentence": sentence,
            "target_word": target_word
        }

    # Binary mask over the encoded sequence marking the target sub-tokens;
    # positions beyond max_length (truncation) are silently dropped.
    target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
    for pos in target_positions:
        if pos < target_mask.size(1):
            target_mask[0, pos] = 1.0

    print(f"\n{'='*60}")
    print(f"Sentence: {sentence}")
    print(f"Target: {target_word}")
    print(f"Target positions: {target_positions}")
    print(f"{'='*60}\n")

    return {
        "sentence": sentence,
        "target_word": target_word,
        "target_positions": target_positions,
        "message": "Model loaded successfully. Full inference requires frame and source models.",
        "note": "This is a placeholder. Integrate with modeling_conceptframemet.py for full predictions."
    }
|
|
|
|
def main():
    """Parse CLI arguments and run a single metaphor prediction."""
    arg_parser = argparse.ArgumentParser(description='ConceptFrameMet Inference')
    arg_parser.add_argument('--model_path', type=str, required=True,
                            help='Path to model directory')
    arg_parser.add_argument('--sentence', type=str, required=True,
                            help='Input sentence')
    arg_parser.add_argument('--target', type=str, required=True,
                            help='Target word')
    arg_parser.add_argument('--device', type=str, default='cpu',
                            choices=['cpu', 'cuda'], help='Device to use')
    opts = arg_parser.parse_args()

    prediction = predict_metaphor(
        sentence=opts.sentence,
        target_word=opts.target,
        model_path=opts.model_path,
        device=opts.device,
    )

    print("\nResult:")
    print(json.dumps(prediction, indent=2))
|
|