Upload src/train.py with huggingface_hub
src/train.py  ADDED  (+570, -0)
@@ -0,0 +1,570 @@
"""
Training script for the GazeInception-Lite model.

Trains both:
1. Single-eye model (89K params) - ultralight, for constrained devices
2. Dual-eye model (137K params) - better accuracy, with face context and lazy-eye handling

Features:
- Gated Inception blocks with learned branch gating
- Coordinate Attention for spatial gaze awareness
- Synthetic data covering dark conditions, glasses, lazy eye, and sensor noise
- TFLite conversion with full integer quantization
- Push to Hugging Face Hub

Based on:
- AGE framework (arxiv:2603.26945) - augmentation pipeline, multi-task approach
- Gated Compression Layers (arxiv:2303.08970) - gating mechanism
- iTracker (arxiv:1606.05814) - dual-eye + face architecture
"""

import os
import json
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Import our modules
from model import build_gaze_inception_lite, build_dual_eye_model
from data_generator import SyntheticGazeDataGenerator, create_tf_dataset, create_single_eye_dataset

def euclidean_distance_metric(y_true, y_pred):
    """Mean Euclidean distance in normalized [0, 1] screen coordinates."""
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1)))


def screen_error_mm(y_true, y_pred, screen_w_mm=65.0, screen_h_mm=140.0):
    """Gaze error in mm, assuming a typical phone screen (65mm x 140mm)."""
    diff = y_true - y_pred
    diff_mm = diff * tf.constant([screen_w_mm, screen_h_mm])
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(diff_mm), axis=-1)))
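
# Worked example for screen_error_mm (hypothetical numbers, not measured
# results): a uniform normalized error of 0.05 corresponds to
#   sqrt((0.05 * 65)^2 + (0.05 * 140)^2) ~= 7.7 mm
# on the assumed 65mm x 140mm portrait phone screen.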


def train_single_eye_model(train_data, val_data, epochs=100, batch_size=128,
                           output_dir='models/single_eye'):
    """Train the lightweight single-eye model."""
    print("\n" + "=" * 60)
    print("Training Single-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2)

    # Cosine decay learning rate. The single-eye dataset contains both eyes of
    # every sample, hence the factor of 2 in steps per epoch.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) * 2 // batch_size),
        alpha=1e-5
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric]
    )

    # Create datasets
    train_ds = create_single_eye_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_single_eye_dataset(val_data, batch_size=batch_size, shuffle=False)

    # Note: ReduceLROnPlateau is deliberately omitted here. It cannot modify an
    # optimizer built with a LearningRateSchedule (Keras raises when the
    # callback tries to read/set the LR), and the cosine schedule already
    # anneals the learning rate.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20,
            mode='min', restore_best_weights=True, verbose=1
        ),
    ]

    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )

    # Save final model
    model.save(os.path.join(output_dir, 'final_model.keras'))

    return model, history


def train_dual_eye_model(train_data, val_data, epochs=100, batch_size=64,
                         output_dir='models/dual_eye'):
    """Train the dual-eye model with face context."""
    print("\n" + "=" * 60)
    print("Training Dual-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_dual_eye_model(
        eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2
    )

    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) // batch_size),
        alpha=1e-5
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric]
    )

    # Create datasets
    train_ds = create_tf_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_tf_dataset(val_data, batch_size=batch_size, shuffle=False)

    # As above, ReduceLROnPlateau is omitted: it is incompatible with a
    # LearningRateSchedule-driven optimizer.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20,
            mode='min', restore_best_weights=True, verbose=1
        ),
    ]

    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )

    model.save(os.path.join(output_dir, 'final_model.keras'))

    return model, history


def convert_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert a Keras model to TFLite, with optional full-integer quantization."""
    print(f"\nConverting to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # Optimization settings for mobile
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        # Full integer quantization (fastest on mobile)
        def representative_dataset_gen():
            for i in range(min(200, len(test_data))):
                yield [test_data[i:i + 1].astype(np.float32)]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        # Without a representative dataset, request float16 weights so the
        # exported file actually matches its *_f16.tflite name (plain
        # Optimize.DEFAULT alone would give dynamic-range int8 weights).
        converter.target_spec.supported_types = [tf.float16]

    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"  TFLite model size: {size_kb:.1f} KB")

    return tflite_model
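
# Minimal on-device usage sketch for the exported INT8 model (illustrative
# only; `img` is assumed to be a 64x64x3 float array in [0, 1]):
#   interpreter = tf.lite.Interpreter(model_path='gaze_inception_lite_single_int8.tflite')
#   interpreter.allocate_tensors()
#   inp = interpreter.get_input_details()[0]
#   interpreter.set_tensor(inp['index'], (img[None] * 255).astype(np.uint8))
#   interpreter.invoke()
#   gaze_xy = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])[0]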


def convert_dual_eye_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert the dual-eye model to TFLite."""
    print(f"\nConverting dual-eye model to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        def representative_dataset_gen():
            # One sample per step, in the model's input order:
            # (left eye, right eye, face).
            for i in range(min(200, len(test_data['left_eye']))):
                yield [
                    test_data['left_eye'][i:i + 1].astype(np.float32),
                    test_data['right_eye'][i:i + 1].astype(np.float32),
                    test_data['face'][i:i + 1].astype(np.float32),
                ]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
            tf.lite.OpsSet.TFLITE_BUILTINS,  # fallback for unsupported ops
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        # Same float16 fallback as the single-eye converter.
        converter.target_spec.supported_types = [tf.float16]

    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"  TFLite model size: {size_kb:.1f} KB")

    return tflite_model


def evaluate_tflite(tflite_path, test_inputs, test_labels, is_dual=False):
    """Evaluate a TFLite model's accuracy and CPU inference speed."""
    print(f"\nEvaluating TFLite model: {tflite_path}")

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"  Input details: {[(d['name'], d['shape'], d['dtype']) for d in input_details]}")
    print(f"  Output details: {[(d['name'], d['shape'], d['dtype']) for d in output_details]}")

    def set_inputs(i):
        """Feed sample i, rescaling [0, 1] floats to uint8 for quantized inputs."""
        if is_dual:
            for idx, detail in enumerate(input_details):
                if detail['dtype'] == np.uint8:
                    data = (test_inputs[idx][i:i + 1] * 255).astype(np.uint8)
                else:
                    data = test_inputs[idx][i:i + 1].astype(np.float32)
                interpreter.set_tensor(detail['index'], data)
        else:
            if input_details[0]['dtype'] == np.uint8:
                data = (test_inputs[i:i + 1] * 255).astype(np.uint8)
            else:
                data = test_inputs[i:i + 1].astype(np.float32)
            interpreter.set_tensor(input_details[0]['index'], data)

    predictions = []
    num_samples = min(500, len(test_labels))

    # Warmup
    for _ in range(5):
        set_inputs(0)
        interpreter.invoke()

    # Benchmark
    start_time = time.time()
    for i in range(num_samples):
        set_inputs(i)
        interpreter.invoke()
        pred = interpreter.get_tensor(output_details[0]['index'])[0]
        predictions.append(pred)

    elapsed = time.time() - start_time

    predictions = np.array(predictions)
    labels = test_labels[:num_samples]

    # Metrics
    eucl_error = np.mean(np.sqrt(np.sum((predictions - labels) ** 2, axis=-1)))

    # Screen error in mm (typical phone: 65mm x 140mm)
    diff_mm = (predictions - labels) * np.array([65.0, 140.0])
    screen_error = np.mean(np.sqrt(np.sum(diff_mm ** 2, axis=-1)))

    # Screen error in cm
    screen_error_cm = screen_error / 10.0

    avg_inference_ms = (elapsed / num_samples) * 1000

    print(f"  Euclidean error (normalized): {eucl_error:.4f}")
    print(f"  Screen error: {screen_error:.1f} mm ({screen_error_cm:.2f} cm)")
    print(f"  Average inference time (CPU): {avg_inference_ms:.2f} ms")
    print(f"  FPS (CPU): {1000 / avg_inference_ms:.1f}")

    return {
        'euclidean_error': float(eucl_error),
        'screen_error_mm': float(screen_error),
        'screen_error_cm': float(screen_error_cm),
        'avg_inference_ms': float(avg_inference_ms),
        'fps': float(1000 / avg_inference_ms),
        'num_test_samples': num_samples,
    }
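
# Note: the (x * 255) -> uint8 rescale in set_inputs matches the [0, 1]
# representative-data range only approximately. A closer mapping, if needed,
# would use the input tensor's quantization parameters:
#   scale, zero_point = input_details[0]['quantization']
#   q = np.clip(np.round(x / scale + zero_point), 0, 255).astype(np.uint8)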


def main():
    print("=" * 60)
    print("GazeInception-Lite: Mobile Eye Gaze Estimation")
    print("=" * 60)

    # Configuration
    NUM_TRAIN = 50000
    NUM_VAL = 5000
    NUM_TEST = 3000
    EPOCHS_SINGLE = 80
    EPOCHS_DUAL = 80
    BATCH_SIZE = 128

    OUTPUT_DIR = '/app/output'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ==========================================
    # Generate synthetic training data
    # ==========================================
    print("\n[1/6] Generating synthetic training data...")
    print(f"  Train: {NUM_TRAIN}, Val: {NUM_VAL}, Test: {NUM_TEST}")
    print("  Augmentations: dark(30%), glasses(25%), lazy_eye(15%), noise(50%)")

    gen = SyntheticGazeDataGenerator(seed=42)

    t0 = time.time()
    train_data = gen.generate_dataset(NUM_TRAIN, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)
    print(f"  Train data generated in {time.time()-t0:.1f}s")

    # Separate seeds keep the validation and test splits disjoint from training.
    gen_val = SyntheticGazeDataGenerator(seed=123)
    val_data = gen_val.generate_dataset(NUM_VAL, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    gen_test = SyntheticGazeDataGenerator(seed=456)
    test_data = gen_test.generate_dataset(NUM_TEST, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    # Also generate condition-specific test sets for robustness evaluation
    gen_dark = SyntheticGazeDataGenerator(seed=789)
    test_dark = gen_dark.generate_dataset(1000, dark_prob=1.0, with_glasses_prob=0.0, lazy_eye_prob=0.0)

    gen_glasses = SyntheticGazeDataGenerator(seed=101)
    test_glasses = gen_glasses.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=1.0, lazy_eye_prob=0.0)

    gen_lazy = SyntheticGazeDataGenerator(seed=202)
    test_lazy = gen_lazy.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=0.0, lazy_eye_prob=1.0)

    print(f"  All data generated. Total time: {time.time()-t0:.1f}s")

    # ==========================================
    # Train Single-Eye Model
    # ==========================================
    print("\n[2/6] Training single-eye model...")
    single_model, single_history = train_single_eye_model(
        train_data, val_data,
        epochs=EPOCHS_SINGLE, batch_size=BATCH_SIZE,
        output_dir=os.path.join(OUTPUT_DIR, 'single_eye')
    )

    # ==========================================
    # Train Dual-Eye Model
    # ==========================================
    print("\n[3/6] Training dual-eye model...")
    dual_model, dual_history = train_dual_eye_model(
        train_data, val_data,
        epochs=EPOCHS_DUAL, batch_size=64,
        output_dir=os.path.join(OUTPUT_DIR, 'dual_eye')
    )

    # ==========================================
    # Convert to TFLite
    # ==========================================
    print("\n[4/6] Converting models to TFLite...")

    # Single eye - float16
    single_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_f16.tflite')
    convert_to_tflite(single_model, single_tflite_f16_path, quantize=False)

    # Single eye - INT8 quantized. Pool left and right eye crops so the
    # representative dataset covers both.
    single_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_int8.tflite')
    all_eyes = np.concatenate([test_data['left_eye'], test_data['right_eye']], axis=0)
    convert_to_tflite(single_model, single_tflite_int8_path, quantize=True, test_data=all_eyes)

    # Dual eye - float16
    dual_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_f16.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_f16_path, quantize=False)

    # Dual eye - INT8 quantized
    dual_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_int8.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_int8_path, quantize=True, test_data=test_data)

    # ==========================================
    # Evaluate TFLite models
    # ==========================================
    print("\n[5/6] Evaluating TFLite models...")

    results = {}

    # Single eye evaluation: all_eyes[:3000] are the left-eye crops (the first
    # half of the concatenation), which align one-to-one with test_data['gaze'].
    print("\n--- Single Eye Model (Float16) ---")
    results['single_f16'] = evaluate_tflite(
        single_tflite_f16_path, all_eyes[:3000], test_data['gaze']
    )

    print("\n--- Single Eye Model (INT8) ---")
    results['single_int8'] = evaluate_tflite(
        single_tflite_int8_path, all_eyes[:3000], test_data['gaze']
    )

    # Dual eye evaluation
    print("\n--- Dual Eye Model (Float16) ---")
    dual_inputs = [test_data['left_eye'], test_data['right_eye'], test_data['face']]
    results['dual_f16'] = evaluate_tflite(
        dual_tflite_f16_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    print("\n--- Dual Eye Model (INT8) ---")
    results['dual_int8'] = evaluate_tflite(
        dual_tflite_int8_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    # Condition-specific evaluation (dual model, float16)
    print("\n--- Robustness Evaluation (Dual Eye, Float16) ---")
    print("\n  [Dark conditions]")
    dark_inputs = [test_dark['left_eye'], test_dark['right_eye'], test_dark['face']]
    results['dual_f16_dark'] = evaluate_tflite(
        dual_tflite_f16_path, dark_inputs, test_dark['gaze'], is_dual=True
    )

    print("\n  [With glasses]")
    glasses_inputs = [test_glasses['left_eye'], test_glasses['right_eye'], test_glasses['face']]
    results['dual_f16_glasses'] = evaluate_tflite(
        dual_tflite_f16_path, glasses_inputs, test_glasses['gaze'], is_dual=True
    )

    print("\n  [Lazy eye / strabismus]")
    lazy_inputs = [test_lazy['left_eye'], test_lazy['right_eye'], test_lazy['face']]
    results['dual_f16_lazy_eye'] = evaluate_tflite(
        dual_tflite_f16_path, lazy_inputs, test_lazy['gaze'], is_dual=True
    )

    # ==========================================
    # Save results and metadata
    # ==========================================
    print("\n[6/6] Saving results...")

    # Model card metadata
    metadata = {
        'model_name': 'GazeInception-Lite',
        'task': 'eye-gaze-estimation',
        'description': 'Lightweight TFLite model for mobile eye gaze estimation on phone screens',
        'architecture': {
            'type': 'Gated Inception Network with Coordinate Attention',
            'single_eye_params': int(single_model.count_params()),
            'dual_eye_params': int(dual_model.count_params()),
            'input_size': '64x64x3',
            'features': [
                'Gated Inception blocks (learned branch gating to skip useless compute)',
                'Coordinate Attention for spatial gaze awareness',
                'Depthwise separable convolutions for efficiency',
                'Dual-eye processing with shared weights (handles lazy eye)',
                'Face context branch (head pose proxy)'
            ]
        },
        'training': {
            'dataset': 'Synthetic (50K train, 5K val, 3K test)',
            'augmentations': [
                'Dark/low-light conditions (30% probability, 15-50% brightness)',
                'Glasses overlay synthesis (25% probability, 10 frame styles)',
                'Lazy eye/strabismus simulation (15% probability)',
                'CMOS sensor noise (50% probability)',
                'Illumination perturbation (directional light gradients)',
                'Diverse skin tones (12 variations)',
                'Diverse eye colors (7 variations)'
            ],
            'optimizer': 'Adam with Cosine Decay LR',
            'initial_lr': 1e-3,
            'loss': 'MSE',
            'epochs': f'{EPOCHS_SINGLE} (single) / {EPOCHS_DUAL} (dual)',
        },
        'tflite_models': {
            'single_eye_f16': {
                'file': 'gaze_inception_lite_single_f16.tflite',
                'size_kb': os.path.getsize(single_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'single_eye_int8': {
                'file': 'gaze_inception_lite_single_int8.tflite',
                'size_kb': os.path.getsize(single_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
            'dual_eye_f16': {
                'file': 'gaze_inception_lite_dual_f16.tflite',
                'size_kb': os.path.getsize(dual_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'dual_eye_int8': {
                'file': 'gaze_inception_lite_dual_int8.tflite',
                'size_kb': os.path.getsize(dual_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
        },
        'evaluation_results': results,
        'references': [
            'AGE Framework - arxiv:2603.26945',
            'Gated Compression Layers - arxiv:2303.08970',
            'iTracker / GazeCapture - arxiv:1606.05814',
            'Coordinate Attention - Hou et al. 2021',
            'MobileNetV2 - arxiv:1801.04381',
        ]
    }

    with open(os.path.join(OUTPUT_DIR, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Save training history
    for name, hist in [('single', single_history), ('dual', dual_history)]:
        hist_dict = {k: [float(v) for v in vals] for k, vals in hist.history.items()}
        with open(os.path.join(OUTPUT_DIR, f'{name}_history.json'), 'w') as f:
            json.dump(hist_dict, f, indent=2)

    # Print summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 60)

    print("\nSingle-Eye Model:")
    print(f"  Parameters: {single_model.count_params():,}")
    print(f"  F16 TFLite: {os.path.getsize(single_tflite_f16_path)/1024:.1f} KB")
    print(f"  INT8 TFLite: {os.path.getsize(single_tflite_int8_path)/1024:.1f} KB")
    if 'single_int8' in results:
        r = results['single_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print("\nDual-Eye Model:")
    print(f"  Parameters: {dual_model.count_params():,}")
    print(f"  F16 TFLite: {os.path.getsize(dual_tflite_f16_path)/1024:.1f} KB")
    print(f"  INT8 TFLite: {os.path.getsize(dual_tflite_int8_path)/1024:.1f} KB")
    if 'dual_int8' in results:
        r = results['dual_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print("\nRobustness (Dual Eye):")
    for condition in ['dark', 'glasses', 'lazy_eye']:
        key = f'dual_f16_{condition}'
        if key in results:
            r = results[key]
            print(f"  {condition}: {r['screen_error_mm']:.1f} mm error")

    print(f"\nOutput directory: {OUTPUT_DIR}")
    print(f"Files: {os.listdir(OUTPUT_DIR)}")


if __name__ == '__main__':
    main()