File size: 5,347 Bytes
bdc2b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
ContextFlow RL Model Inference Example

This script demonstrates how to load the trained checkpoint and make predictions.
"""

import pickle
import numpy as np
import sys
import os

# Add current directory to path
sys.path.insert(0, os.path.dirname(__file__))

from feature_extractor import FeatureExtractor


# Doubt action labels, one per Q-network output (10 actions).
# DoubtPredictor.predict maps the i-th Q-value to DOUBT_ACTIONS[i], so the
# order of this list is part of the model contract and must not be changed.
# NOTE(review): entries 3 and 7 both concern regularization; presumably this
# mirrors the action space used in training — confirm before editing.
DOUBT_ACTIONS = [
    "what_is_backpropagation",   # 0
    "why_gradient_descent",      # 1
    "how_overfitting_works",     # 2
    "explain_regularization",    # 3
    "what_loss_function",        # 4
    "how_optimization_works",    # 5
    "explain_learning_rate",     # 6
    "what_regularization",       # 7
    "how_batch_norm_works",      # 8
    "explain_softmax",           # 9
]


class DoubtPredictor:
    """Predict a learner's likely doubt with the trained Q-network.

    The forward pass is a plain NumPy re-implementation of a 3-layer MLP
    (``layer1`` -> ReLU -> ``layer2`` -> ReLU -> ``output``) whose parameters
    are read from the checkpoint's ``q_network_weights`` mapping.
    """

    # Parameter names the NumPy forward pass needs, in (w1, b1, w2, b2, w3, b3) order.
    _REQUIRED_WEIGHTS = (
        'layer1.weight', 'layer1.bias',
        'layer2.weight', 'layer2.bias',
        'output.weight', 'output.bias',
    )

    def __init__(self, checkpoint_path: str):
        """Load the pickled checkpoint and create the feature extractor.

        Args:
            checkpoint_path: Path to a pickled checkpoint object exposing
                ``policy_version``, ``training_stats`` (a mapping with
                ``.get``) and ``q_network_weights``.
        """
        self.extractor = FeatureExtractor()

        # SECURITY: unpickling can execute arbitrary code embedded in the file.
        # Only load checkpoints that come from a source you trust.
        with open(checkpoint_path, 'rb') as f:
            self.checkpoint = pickle.load(f)

        print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
        print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")

    def extract_state(self, **kwargs) -> np.ndarray:
        """Build a state vector; all keyword arguments go to FeatureExtractor.extract_state."""
        return self.extractor.extract_state(**kwargs)

    def _q_values(self, state: np.ndarray) -> np.ndarray:
        """Return the Q-network output for ``state`` (random values if the net is absent).

        Raises:
            KeyError: if ``layer1.weight`` is present but another required
                parameter is missing (a partial or mismatched checkpoint).
        """
        q_weights = self.checkpoint.q_network_weights

        if 'layer1.weight' not in q_weights:
            import warnings

            # Best-effort fallback kept so the demo still runs, but made
            # visible: these numbers are random noise, not model predictions.
            warnings.warn(
                "Checkpoint has no 'layer1.weight'; returning RANDOM Q-values.",
                RuntimeWarning,
                stacklevel=3,  # point at the caller of predict()
            )
            return np.random.randn(len(DOUBT_ACTIONS)) * 0.5

        missing = [name for name in self._REQUIRED_WEIGHTS if name not in q_weights]
        if missing:
            raise KeyError(f"Checkpoint Q-network weights are missing: {missing}")

        # np.asarray lets array-likes (e.g. nested lists) work as well as ndarrays.
        w1, b1, w2, b2, w3, b3 = (
            np.asarray(q_weights[name]) for name in self._REQUIRED_WEIGHTS
        )

        h1 = np.maximum(np.dot(state, w1.T) + b1, 0)  # ReLU
        h2 = np.maximum(np.dot(h1, w2.T) + b2, 0)  # ReLU
        return np.dot(h2, w3.T) + b3

    def predict(self, state: np.ndarray, top_k: int = 3) -> dict:
        """Rank the doubt actions for one state vector.

        Args:
            state: Feature vector, e.g. from ``extract_state``. Assumed to be
                a single 1-D sample — TODO confirm; a batched 2-D input would
                break the top-k indexing below.
            top_k: Number of highest-scoring actions to return (default 3).

        Returns:
            dict with ``predicted_doubt`` (best action label), ``confidence``
            (its raw Q-value — not a probability) and ``top_predictions``
            (list of ``{'action', 'q_value'}`` dicts, best first).

        Raises:
            ValueError: if ``top_k`` is less than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        q_values = self._q_values(state)

        # Indices sorted by descending Q-value, truncated to the top_k best.
        top_indices = np.argsort(q_values)[::-1][:top_k]
        best = top_indices[0]

        return {
            'predicted_doubt': DOUBT_ACTIONS[best],
            'confidence': float(q_values[best]),
            'top_predictions': [
                {
                    'action': DOUBT_ACTIONS[i],
                    'q_value': float(q_values[i])
                }
                for i in top_indices
            ]
        }


def example_inference(checkpoint_path: str = 'checkpoint.pkl') -> None:
    """Run a few demo scenarios through a DoubtPredictor and print the results.

    Args:
        checkpoint_path: Path to the pickled checkpoint. The default
            ``'checkpoint.pkl'`` is resolved against the current working
            directory, not the script's directory.

    If the checkpoint file does not exist, prints a download hint and returns
    without running any inference.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Download from: https://huggingface.co/namish10/contextflow-rl")
        return

    predictor = DoubtPredictor(checkpoint_path)

    print("\n" + "="*60)
    print("EXAMPLE INFERENCES")
    print("="*60)

    # Each entry: (header shown before the scenario, kwargs for extract_state).
    scenarios = [
        (
            "Scenario 1: Beginner ML student",
            dict(
                topic="neural networks",
                progress=0.3,
                confusion_signals={
                    'mouse_hesitation': 3.0,
                    'scroll_reversals': 6,
                    'time_on_page': 45,
                    'back_button': 3,
                    'copy_attempts': 1
                },
                gesture_signals={
                    'pinch': 2,
                    'point': 5
                },
                time_spent=120
            ),
        ),
        (
            # Advanced learner struggling with regularization.
            "Scenario 2: Advanced learner, high confusion signals",
            dict(
                topic="deep learning",
                progress=0.7,
                confusion_signals={
                    'mouse_hesitation': 4.5,
                    'scroll_reversals': 8,
                    'time_on_page': 280,
                    'back_button': 5,
                    'copy_attempts': 2,
                    'search_usage': 3
                },
                gesture_signals={
                    'pinch': 8,
                    'swipe_left': 4,
                    'point': 10
                },
                time_spent=600
            ),
        ),
        (
            "Scenario 3: Quick learner, low confusion",
            dict(
                topic="python programming",
                progress=0.9,
                confusion_signals={
                    'mouse_hesitation': 0.5,
                    'scroll_reversals': 1,
                    'time_on_page': 20,
                    'back_button': 0
                },
                gesture_signals={
                    'swipe_down': 5,
                    'point': 3
                },
                time_spent=60
            ),
        ),
    ]

    for title, features in scenarios:
        print(f"\n[{title}]")
        state = predictor.extract_state(**features)
        result = predictor.predict(state)
        print(f"  Predicted doubt: {result['predicted_doubt']}")
        print(f"  Q-value: {result['confidence']:.4f}")

    print("\n" + "="*60)


if __name__ == '__main__':
    # Script entry point: run the demo scenarios (no CLI arguments are parsed).
    example_inference()