File size: 3,588 Bytes
1b12abd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Simple inference script for ConceptFrameMet model
"""

import torch
from transformers import RobertaTokenizer, RobertaModel
import json
import argparse


def load_model(model_path):
    """Load the ConceptFrameMet tokenizer, raw weights, and config.

    Args:
        model_path: Path to a directory containing the tokenizer files,
            ``pytorch_model.bin``, and ``config.json``.

    Returns:
        Tuple ``(tokenizer, model_weights, config)`` where
        ``model_weights`` is the raw state dict loaded from
        ``pytorch_model.bin`` and ``config`` is the parsed JSON config.
    """

    # Load tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(model_path)

    # Load model weights. weights_only=True prevents arbitrary code
    # execution from a malicious pickle payload; a plain state dict
    # (tensors only) loads fine under this restriction.
    model_weights = torch.load(
        f"{model_path}/pytorch_model.bin",
        map_location='cpu',
        weights_only=True,
    )

    # Load config (explicit encoding so behavior doesn't depend on locale)
    with open(f"{model_path}/config.json", 'r', encoding='utf-8') as f:
        config = json.load(f)

    print(f"✓ Model loaded from {model_path}")
    print(f"  Model type: {config.get('model_type', 'roberta')}")

    return tokenizer, model_weights, config


def predict_metaphor(sentence, target_word, model_path, device='cpu'):
    """
    Predict if a target word is metaphorical in the given sentence

    Args:
        sentence: Input sentence
        target_word: Target word to analyze
        model_path: Path to model directory
        device: Device to run on ('cpu' or 'cuda'); currently unused
            because full inference is not wired up yet

    Returns:
        Dictionary with the located target sub-token positions, or a
        dictionary containing an ``"error"`` key when the target word's
        tokens cannot be found in the sentence
    """

    tokenizer, model_weights, config = load_model(model_path)

    # Tokenize input
    inputs = tokenizer(
        sentence,
        max_length=150,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    sentence_tokens = tokenizer.tokenize(sentence)

    # RoBERTa's byte-level BPE is whitespace-sensitive: a word in the
    # middle of a sentence tokenizes with a leading-space marker ("Ġ"),
    # while tokenizing the bare word does not produce that marker. Try
    # the space-prefixed variant first so mid-sentence targets match;
    # fall back to the bare form for sentence-initial targets.
    target_positions = []
    for target_tokens in (tokenizer.tokenize(" " + target_word),
                          tokenizer.tokenize(target_word)):
        if not target_tokens:
            continue
        for i in range(len(sentence_tokens) - len(target_tokens) + 1):
            if sentence_tokens[i:i + len(target_tokens)] == target_tokens:
                # +1 offsets past the leading <s> (CLS) token added by
                # the tokenizer __call__ above
                target_positions = list(range(i + 1, i + 1 + len(target_tokens)))
                break
        if target_positions:
            break

    if not target_positions:
        return {
            "error": "Target word not found in sentence",
            "sentence": sentence,
            "target_word": target_word
        }

    # Create target mask over the padded sequence; positions past the
    # truncated/padded length (max_length=150) are silently dropped
    target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
    for pos in target_positions:
        if pos < target_mask.size(1):
            target_mask[0, pos] = 1.0

    print(f"\n{'='*60}")
    print(f"Sentence: {sentence}")
    print(f"Target: {target_word}")
    print(f"Target positions: {target_positions}")
    print(f"{'='*60}\n")

    # For now, return basic info
    # Full inference requires loading the complete model architecture
    return {
        "sentence": sentence,
        "target_word": target_word,
        "target_positions": target_positions,
        "message": "Model loaded successfully. Full inference requires frame and source models.",
        "note": "This is a placeholder. Integrate with modeling_conceptframemet.py for full predictions."
    }


def main():
    """CLI entry point: parse command-line arguments and run one prediction."""
    arg_parser = argparse.ArgumentParser(description='ConceptFrameMet Inference')
    arg_parser.add_argument('--model_path', type=str, required=True, help='Path to model directory')
    arg_parser.add_argument('--sentence', type=str, required=True, help='Input sentence')
    arg_parser.add_argument('--target', type=str, required=True, help='Target word')
    arg_parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'], help='Device to use')

    opts = arg_parser.parse_args()

    # Run the (placeholder) prediction pipeline with the parsed options.
    outcome = predict_metaphor(
        sentence=opts.sentence,
        target_word=opts.target,
        model_path=opts.model_path,
        device=opts.device,
    )

    print("\nResult:")
    print(json.dumps(outcome, indent=2))


if __name__ == "__main__":
    main()