File size: 5,347 Bytes
bdc2b78 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """
ContextFlow RL Model Inference Example
This script demonstrates how to load the trained checkpoint and make predictions.
"""
import pickle
import numpy as np
import sys
import os
# Make modules that live next to this script (e.g. feature_extractor) importable
# no matter which working directory the script is launched from. abspath()
# guards against a relative __file__ producing '' (which would resolve against
# the current working directory instead of the script's directory).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from feature_extractor import FeatureExtractor
# Doubt action labels (10 actions). A label's position in this list is the
# index of the Q-value it names (predict() maps argsort indices straight into
# it), so the order presumably must match the action order used when the
# checkpoint was trained - confirm before reordering or editing.
DOUBT_ACTIONS = list(
    (
        "what_is_backpropagation",
        "why_gradient_descent",
        "how_overfitting_works",
        "explain_regularization",
        "what_loss_function",
        "how_optimization_works",
        "explain_learning_rate",
        "what_regularization",
        "how_batch_norm_works",
        "explain_softmax",
    )
)
class DoubtPredictor:
    """Simple doubt predictor using the trained Q-network.

    Wraps a ``FeatureExtractor`` (used to build state vectors) and a pickled
    checkpoint object. The checkpoint is read through three attributes:
    ``policy_version``, ``training_stats`` (queried with ``.get``) and
    ``q_network_weights`` (a mapping from parameter name to weights).
    """

    def __init__(self, checkpoint_path: str):
        """Create the feature extractor and load the checkpoint.

        Args:
            checkpoint_path: Filesystem path to the pickled checkpoint.

        Raises:
            OSError: If the file cannot be opened. Unpickling errors propagate.
        """
        self.extractor = FeatureExtractor()
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load checkpoints that come from a source you trust.
        with open(checkpoint_path, 'rb') as f:
            self.checkpoint = pickle.load(f)
        print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
        print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")

    def extract_state(self, **kwargs) -> np.ndarray:
        """Build a state vector by forwarding all keyword features to the extractor."""
        return self.extractor.extract_state(**kwargs)

    @staticmethod
    def _forward(q_weights, state) -> np.ndarray:
        """Return Q-values from the Linear-ReLU-Linear-ReLU-Linear network in ``q_weights``.

        Each weight matrix is multiplied as ``x @ W.T``, i.e. stored as
        (out_features, in_features). Weights are coerced with ``np.asarray`` so
        plain nested lists work as well as ndarrays.
        """
        hidden = np.asarray(state)
        for layer in ('layer1', 'layer2'):
            weight = np.asarray(q_weights[f'{layer}.weight'])
            bias = np.asarray(q_weights[f'{layer}.bias'])
            hidden = np.maximum(np.dot(hidden, weight.T) + bias, 0)  # ReLU
        out_weight = np.asarray(q_weights['output.weight'])
        out_bias = np.asarray(q_weights['output.bias'])
        return np.dot(hidden, out_weight.T) + out_bias

    def predict(self, state: np.ndarray) -> dict:
        """Predict doubt actions from a state vector.

        Args:
            state: Feature vector, assumed 1-D with a length matching the first
                layer's input size (presumably what ``extract_state`` returns;
                confirm against FeatureExtractor).

        Returns:
            dict with
              - ``'predicted_doubt'``: label of the highest-scoring action,
              - ``'confidence'``: that action's raw Q-value (not a probability),
              - ``'top_predictions'``: up to 3 ``{'action', 'q_value'}`` dicts,
                highest Q-value first.
        """
        q_weights = self.checkpoint.q_network_weights
        if 'layer1.weight' in q_weights:
            q_values = self._forward(q_weights, state)
        else:
            # No recognisable Q-network weights. Keep the demo running with
            # random scores, but make it explicit that these are NOT model
            # predictions. Size the vector from the label list so every index
            # maps to a valid action.
            import warnings
            warnings.warn(
                "Checkpoint has no 'layer1.weight' Q-network weights; "
                "returning random predictions.",
                RuntimeWarning,
                stacklevel=2,
            )
            q_values = np.random.randn(len(DOUBT_ACTIONS)) * 0.5

        # Indices of the (up to) 3 largest Q-values, best first.
        top_indices = np.argsort(q_values)[::-1][:3]
        best = top_indices[0]
        return {
            'predicted_doubt': DOUBT_ACTIONS[best],
            'confidence': float(q_values[best]),
            'top_predictions': [
                {
                    'action': DOUBT_ACTIONS[i],
                    'q_value': float(q_values[i])
                }
                for i in top_indices
            ]
        }
def example_inference():
    """Run the three demo scenarios against a locally available checkpoint.

    Looks for ``checkpoint.pkl`` in the current working directory first and,
    failing that, next to this script. If neither exists, prints a download
    hint and returns ``None`` without running anything.
    """
    checkpoint_name = 'checkpoint.pkl'
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # cwd first (original behaviour), then the script's own directory.
    candidates = [checkpoint_name, os.path.join(script_dir, checkpoint_name)]
    checkpoint_path = next((p for p in candidates if os.path.exists(p)), None)
    if checkpoint_path is None:
        print(f"Checkpoint not found: {checkpoint_name}")
        print("Download from: https://huggingface.co/namish10/contextflow-rl")
        return
    predictor = DoubtPredictor(checkpoint_path)
    print("\n" + "="*60)
    print("EXAMPLE INFERENCES")
    print("="*60)

    # (title, keyword features for extract_state) for each demo scenario.
    scenarios = [
        # Example 1: Beginner ML student
        (
            "Scenario 1: Beginner ML student",
            dict(
                topic="neural networks",
                progress=0.3,
                confusion_signals={
                    'mouse_hesitation': 3.0,
                    'scroll_reversals': 6,
                    'time_on_page': 45,
                    'back_button': 3,
                    'copy_attempts': 1
                },
                gesture_signals={
                    'pinch': 2,
                    'point': 5
                },
                time_spent=120
            ),
        ),
        # Example 2: Advanced learner struggling with regularization
        (
            "Scenario 2: Advanced learner, high confusion signals",
            dict(
                topic="deep learning",
                progress=0.7,
                confusion_signals={
                    'mouse_hesitation': 4.5,
                    'scroll_reversals': 8,
                    'time_on_page': 280,
                    'back_button': 5,
                    'copy_attempts': 2,
                    'search_usage': 3
                },
                gesture_signals={
                    'pinch': 8,
                    'swipe_left': 4,
                    'point': 10
                },
                time_spent=600
            ),
        ),
        # Example 3: Quick learner, low confusion
        (
            "Scenario 3: Quick learner, low confusion",
            dict(
                topic="python programming",
                progress=0.9,
                confusion_signals={
                    'mouse_hesitation': 0.5,
                    'scroll_reversals': 1,
                    'time_on_page': 20,
                    'back_button': 0
                },
                gesture_signals={
                    'swipe_down': 5,
                    'point': 3
                },
                time_spent=60
            ),
        ),
    ]

    for title, features in scenarios:
        print(f"\n[{title}]")
        state = predictor.extract_state(**features)
        result = predictor.predict(state)
        print(f" Predicted doubt: {result['predicted_doubt']}")
        print(f" Q-value: {result['confidence']:.4f}")
    print("\n" + "="*60)
if __name__ == "__main__":
    # Only run the demo when executed directly, not when imported as a module.
    example_inference()
|